Skip to content

Commit 4145f6b

Browse files
committed
addressed PR comments
1 parent 7f3cbc5 commit 4145f6b

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/diffusers/models/controlnets/controlnet_sana.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ def forward(
246246

247247
# 2. Transformer blocks
248248
block_res_samples = ()
249-
for block in self.transformer_blocks:
250-
if torch.is_grad_enabled() and self.gradient_checkpointing:
249+
if torch.is_grad_enabled() and self.gradient_checkpointing:
250+
for block in self.transformer_blocks:
251251
hidden_states = self._gradient_checkpointing_func(
252252
block,
253253
hidden_states,
@@ -258,7 +258,9 @@ def forward(
258258
post_patch_height,
259259
post_patch_width,
260260
)
261-
else:
261+
block_res_samples = block_res_samples + (hidden_states,)
262+
else:
263+
for block in self.transformer_blocks:
262264
hidden_states = block(
263265
hidden_states,
264266
attention_mask,
@@ -268,7 +270,7 @@ def forward(
268270
post_patch_height,
269271
post_patch_width,
270272
)
271-
block_res_samples = block_res_samples + (hidden_states,)
273+
block_res_samples = block_res_samples + (hidden_states,)
272274

273275
# 3. ControlNet blocks
274276
controlnet_block_res_samples = ()

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ def forward(
434434
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
435435

436436
# 2. Transformer blocks
437-
for index_block, block in enumerate(self.transformer_blocks):
438-
if torch.is_grad_enabled() and self.gradient_checkpointing:
437+
if torch.is_grad_enabled() and self.gradient_checkpointing:
438+
for index_block, block in enumerate(self.transformer_blocks):
439439
hidden_states = self._gradient_checkpointing_func(
440440
block,
441441
hidden_states,
@@ -446,8 +446,11 @@ def forward(
446446
post_patch_height,
447447
post_patch_width,
448448
)
449+
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
450+
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
449451

450-
else:
452+
else:
453+
for index_block, block in enumerate(self.transformer_blocks):
451454
hidden_states = block(
452455
hidden_states,
453456
attention_mask,
@@ -457,8 +460,8 @@ def forward(
457460
post_patch_height,
458461
post_patch_width,
459462
)
460-
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
461-
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
463+
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
464+
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
462465

463466
# 3. Normalization
464467
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)

0 commit comments

Comments
 (0)