Skip to content

Commit 771344d

Browse files
charliewwdev and claude committed
fix MPS GPU support and defensive API checks
- _apply_offloading: skip cpu offload on MPS, move directly to GPU - _apply_vae_opts: add hasattr checks (WanPipeline lacks vae_slicing) - All backends pass device to _apply_offloading for proper routing - E2e test: MPS uses offload=none to run on GPU directly Tested: Wan 1.3B on MPS generates 17 frames in 294s (vs 414s on CPU) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a7e99ae commit 771344d

File tree

6 files changed

+28
-13
lines changed

6 files changed

+28
-13
lines changed

animatediff/backends/cogvideo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def load(
6767
instance = cls(pipe, model_variant=model_variant)
6868

6969
if offload_strategy != "none":
70-
instance._apply_offloading(pipe, offload_strategy)
70+
instance._apply_offloading(pipe, offload_strategy, device=device)
7171
else:
7272
pipe.to(device)
7373

animatediff/backends/hunyuan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def load(
6464
instance = cls(pipe, model_variant=model_variant)
6565

6666
if offload_strategy != "none":
67-
instance._apply_offloading(pipe, offload_strategy)
67+
instance._apply_offloading(pipe, offload_strategy, device=device)
6868
else:
6969
pipe.to(device)
7070

animatediff/backends/ltx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def load(
6262
instance = cls(pipe, model_variant=model_variant)
6363

6464
if offload_strategy != "none":
65-
instance._apply_offloading(pipe, offload_strategy)
65+
instance._apply_offloading(pipe, offload_strategy, device=device)
6666
else:
6767
pipe.to(device)
6868

animatediff/backends/wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def load(
7777

7878
# Apply offloading (must be before .to(device) for cpu offload)
7979
if offload_strategy != "none":
80-
instance._apply_offloading(pipe, offload_strategy)
80+
instance._apply_offloading(pipe, offload_strategy, device=device)
8181
else:
8282
pipe.to(device)
8383

animatediff/core/base_pipeline.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,29 @@ def _make_generator(self, seed: int, device: str) -> Optional[torch.Generator]:
8282
return torch.Generator(device=device).manual_seed(seed)
8383
return None
8484

85-
def _apply_offloading(self, pipe, strategy: str):
86-
"""Apply memory offloading strategy to a diffusers pipeline."""
87-
if strategy == "model_cpu":
85+
def _apply_offloading(self, pipe, strategy: str, device: str = "cuda"):
86+
"""Apply memory offloading strategy to a diffusers pipeline.
87+
88+
Note: CPU offloading only works with CUDA. For MPS, we skip offloading
89+
and move the full pipeline to the device instead.
90+
"""
91+
# CPU offloading requires CUDA — skip for MPS/CPU and just move to device
92+
if device != "cuda" and device != "cpu":
93+
logger.info(f"Offloading not supported on {device}, moving pipeline to {device}")
94+
pipe.to(device)
95+
return
96+
97+
if strategy == "model_cpu" and hasattr(pipe, "enable_model_cpu_offload"):
8898
pipe.enable_model_cpu_offload()
89-
elif strategy == "sequential_cpu":
99+
elif strategy == "sequential_cpu" and hasattr(pipe, "enable_sequential_cpu_offload"):
90100
pipe.enable_sequential_cpu_offload()
101+
else:
102+
logger.warning(f"Offload strategy '{strategy}' not available, moving to {device}")
103+
pipe.to(device)
91104

92105
def _apply_vae_opts(self, pipe, slicing: bool = True, tiling: bool = False):
93106
"""Apply VAE memory optimizations."""
94-
if slicing:
107+
if slicing and hasattr(pipe, "enable_vae_slicing"):
95108
pipe.enable_vae_slicing()
96109
if tiling and hasattr(pipe, "enable_vae_tiling"):
97110
pipe.enable_vae_tiling()

tests/test_e2e_generate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,28 @@ def test_wan_1_3b_generate():
2727
if torch.cuda.is_available():
2828
device = "cuda"
2929
dtype = torch.float16
30+
offload = "model_cpu" # CUDA supports CPU offload
3031
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
3132
device = "mps"
32-
dtype = torch.float32 # MPS works best with float32 for this model
33+
dtype = torch.float32 # MPS requires float32 for Wan
34+
offload = "none" # MPS: load directly to GPU (no cpu offload support)
3335
else:
3436
print("SKIP: No GPU available (need CUDA or MPS)")
3537
return
3638

3739
print(f"\n{'='*60}")
38-
print(f"Device: {device} | dtype: {dtype}")
40+
print(f"Device: {device} | dtype: {dtype} | offload: {offload}")
3941
print(f"{'='*60}")
4042

4143
# Load model (will download ~5GB on first run)
4244
print("\n[1/3] Loading Wan 2.1 1.3B...")
4345
t0 = time.time()
4446
backend = WanBackend.load(
45-
model_path=None, # auto: Wan-AI/Wan2.1-T2V-1.3B
47+
model_path=None, # auto: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
4648
torch_dtype=dtype,
4749
device=device,
4850
quantization="none",
49-
offload_strategy="model_cpu", # save memory
51+
offload_strategy=offload,
5052
enable_vae_slicing=True,
5153
model_variant="1.3B",
5254
)

0 commit comments

Comments (0)