@@ -341,6 +341,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
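To show how the new argument is meant to be consumed, here is a hedged usage sketch. The `model` handle, the tensor names, and the choice of indices [7, 8, 9] are illustrative assumptions, not part of this diff; `pooled_projections` and `timestep` are assumed from the rest of the forward signature, which the hunks above do not show.

    # Hypothetical call site: one denoising forward pass with transformer
    # blocks 7-9 skipped (e.g. for skip-layer guidance experiments).
    noise_pred_skip = model(
        hidden_states=latents,                # illustrative latent tensor
        encoder_hidden_states=prompt_embeds,  # illustrative text embeddings
        pooled_projections=pooled_embeds,     # assumed from the full signature
        timestep=timestep,
        skip_layers=[7, 8, 9],                # block indices to skip
        return_dict=False,
    )[0]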
@@ -390,7 +393,10 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
                     joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                     joint_attention_kwargs=joint_attention_kwargs,
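Stripped of the gradient-checkpointing branch, the loop above reduces to the following standalone pattern. This is a minimal sketch of the skip logic only; `blocks` and `run_blocks` are stand-ins for `self.transformer_blocks` and the real forward pass.

    from typing import List, Optional

    def run_blocks(x, blocks, skip_layers: Optional[List[int]] = None):
        # Mirrors the `is_skip` check in the diff: a block runs only when
        # its index is not listed in `skip_layers`.
        for index_block, block in enumerate(blocks):
            is_skip = skip_layers is not None and index_block in skip_layers
            if not is_skip:
                x = block(x)
        return x

With the default `skip_layers=None`, `is_skip` is always `False`, so the skip logic is a no-op for existing callers.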