Skip to content

Commit 52b9b61

Browse files
committed
use unwrap_module for torch compiled modules
1 parent 2238f55 commit 52b9b61

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
scale_lora_layers,
3131
unscale_lora_layers,
3232
)
33-
from ...utils.torch_utils import is_compiled_module, randn_tensor
33+
from ...utils.torch_utils import randn_tensor, unwrap_module
3434
from ..controlnet.multicontrolnet import MultiControlNetModel
3535
from ..modular_pipeline import (
3636
AutoPipelineBlocks,
@@ -2545,7 +2545,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
25452545
data.height = data.height * pipeline.vae_scale_factor
25462546
data.width = data.width * pipeline.vae_scale_factor
25472547

2548-
controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
2548+
controlnet = unwrap_module(pipeline.controlnet)
25492549

25502550
# (1.1)
25512551
# control_guidance_start/control_guidance_end (align format)
@@ -2973,7 +2973,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
29732973
data.height = data.height * pipeline.vae_scale_factor
29742974
data.width = data.width * pipeline.vae_scale_factor
29752975

2976-
controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
2976+
controlnet = unwrap_module(pipeline.controlnet)
29772977

29782978
# (1.1)
29792979
# control guidance

0 commit comments

Comments
 (0)