2 files changed: +14 −9 lines changed

@@ -246,8 +246,8 @@ def forward(
 
         # 2. Transformer blocks
         block_res_samples = ()
-        for block in self.transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -258,7 +258,9 @@ def forward(
                     post_patch_height,
                     post_patch_width,
                 )
-            else:
+                block_res_samples = block_res_samples + (hidden_states,)
+        else:
+            for block in self.transformer_blocks:
                 hidden_states = block(
                     hidden_states,
                     attention_mask,
@@ -268,7 +270,7 @@ def forward(
                     post_patch_height,
                     post_patch_width,
                 )
-            block_res_samples = block_res_samples + (hidden_states,)
+                block_res_samples = block_res_samples + (hidden_states,)
 
         # 3. ControlNet blocks
         controlnet_block_res_samples = ()
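This first file (the ControlNet module, judging by `controlnet_block_res_samples`) hoists the `torch.is_grad_enabled() and self.gradient_checkpointing` test out of the block loop, so it is evaluated once per forward pass instead of once per block; the duplicated loop then has to append every block output to `block_res_samples` in both branches. A minimal runnable sketch of the pattern, where `TinyModel` and its `nn.Linear` blocks are hypothetical stand-ins and `torch.utils.checkpoint.checkpoint` stands in for `self._gradient_checkpointing_func`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyModel(nn.Module):
    def __init__(self, num_blocks=3, dim=8):
        super().__init__()
        # Hypothetical stand-in blocks for self.transformer_blocks
        self.transformer_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        block_res_samples = ()
        # After the restructure, the branch condition is evaluated once,
        # not once per block.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
                block_res_samples = block_res_samples + (hidden_states,)
        else:
            for block in self.transformer_blocks:
                hidden_states = block(hidden_states)
                block_res_samples = block_res_samples + (hidden_states,)
        return hidden_states, block_res_samples

x = torch.randn(2, 8, requires_grad=True)
out, res = TinyModel()(x)
assert len(res) == 3  # each branch appends one residual per block
```

The hunks that follow (starting at `@@ -434,8 +434,8 @@`) apply the same restructure to the second file.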
@@ -434,8 +434,8 @@ def forward(
         encoder_hidden_states = self.caption_norm(encoder_hidden_states)
 
         # 2. Transformer blocks
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -446,8 +446,11 @@ def forward(
                     post_patch_height,
                     post_patch_width,
                 )
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
-            else:
+        else:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states,
                     attention_mask,
@@ -457,8 +460,8 @@ def forward(
                     post_patch_height,
                     post_patch_width,
                 )
-            if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
-                hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
 
         # 3. Normalization
         hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
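In this second file (the transformer that consumes `controlnet_block_samples`), splitting the loop means the per-block residual injection has to be duplicated into both branches. The guard `0 < index_block <= len(controlnet_block_samples)` skips the first block and stops adding once the supplied samples run out. A small sketch of just the index logic, with stand-in values and more blocks than samples:

```python
import torch

# Hypothetical stand-ins: 4 transformer blocks, only 2 ControlNet residuals.
hidden_states = torch.zeros(3)
controlnet_block_samples = [torch.ones(3), 2 * torch.ones(3)]

for index_block in range(4):  # stands in for enumerate(self.transformer_blocks)
    # ... the block itself would run here ...
    if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
        # block i (i >= 1) consumes sample i - 1; block 0 and overflow blocks add nothing
        hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

print(hidden_states)  # tensor([3., 3., 3.]): 1 + 2, injected after blocks 1 and 2
```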