Skip to content

Commit a41e4c5

Browse files
authored
[Chore] add disable forward chunking to SD3 transformer. (#8838)
add disable forward chunking to SD3 transformer.
1 parent 12625c1 commit a41e4c5

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int
139139
for module in self.children():
140140
fn_recursive_feed_forward(module, chunk_size, dim)
141141

142+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
143+
def disable_forward_chunking(self):
144+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
145+
if hasattr(module, "set_chunk_feed_forward"):
146+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
147+
148+
for child in module.children():
149+
fn_recursive_feed_forward(child, chunk_size, dim)
150+
151+
for module in self.children():
152+
fn_recursive_feed_forward(module, None, 0)
153+
142154
@property
143155
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
144156
def attn_processors(self) -> Dict[str, AttentionProcessor]:

0 commit comments

Comments
 (0)