8686 "User Prompt: " ,
8787]
8888
+class SanaVanillaAttnProcessor:
+    r"""
+    Processor for implementing scaled dot-product attention to support JVP calculation during training.
+    """
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def scaled_dot_product_attention(
+        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    ) -> torch.Tensor:
+        # Unfused reference implementation of SDPA so that forward-mode autodiff
+        # (torch.func.jvp) can differentiate through every op.
+        B, H, L, S = *query.size()[:-1], key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        hidden_states = self.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states


class Text2ImageDataset(Dataset):
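The processor above spells out scaled dot-product attention with plain tensor ops because the training loop pushes a JVP (forward-mode derivative) through the model, and torch.func.jvp needs every op in the forward pass to define a forward-mode derivative. A minimal standalone sketch of that idea (shapes and names here are illustrative, not taken from this script):

    import math
    import torch

    def manual_sdpa(query, key, value):
        # Same math as the processor's scaled_dot_product_attention, without mask/dropout.
        scale = 1 / math.sqrt(query.size(-1))
        attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
        return attn @ value

    q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))
    tq, tk, tv = (torch.randn(1, 2, 16, 8) for _ in range(3))
    # The matmul/softmax decomposition gives forward-mode autodiff something to trace through.
    out, tangent_out = torch.func.jvp(manual_sdpa, (q, k, v), (tq, tk, tv))
    print(out.shape, tangent_out.shape)  # both torch.Size([1, 2, 16, 8])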
@@ -109,7 +189,6 @@ def __init__(self, hf_dataset, resolution=1024):
            T.Lambda(lambda img: img.convert("RGB")),
            T.Resize(resolution),  # Image.BICUBIC
            T.CenterCrop(resolution),
-            # T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])
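For reference, the Compose pipeline kept above maps a PIL image to a CHW tensor in [-1, 1], since Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5 per channel. A small standalone sketch (using a toy 64-pixel resolution in place of the script's resolution argument):

    from PIL import Image
    import torchvision.transforms as T

    transform = T.Compose([
        T.Lambda(lambda img: img.convert("RGB")),
        T.Resize(64),
        T.CenterCrop(64),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])
    x = transform(Image.new("RGB", (96, 80), color=(255, 0, 0)))
    print(x.shape, x.min().item(), x.max().item())  # torch.Size([3, 64, 64]), -1.0, 1.0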
@@ -132,7 +211,7 @@ def __getitem__(self, idx):
            'image': image_tensor
        }

-# TODO here
+
def save_model_card(
    repo_id: str,
    images=None,
@@ -807,7 +886,6 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
        return (trigflow_model_out,)


-
def compute_density_for_timestep_sampling_scm(
    batch_size: int, logit_mean: float = None, logit_std: float = None
):
@@ -820,7 +898,6 @@ def compute_density_for_timestep_sampling_scm(
    return u


-
def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
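Only the tail of compute_density_for_timestep_sampling_scm is visible in this hunk; as a rough, hedged illustration of the logit-normal sampling its signature suggests (an assumption for illustration, not the script's actual body):

    import torch

    def sample_logit_normal(batch_size: int, logit_mean: float, logit_std: float) -> torch.Tensor:
        # Draw u ~ N(logit_mean, logit_std); downstream code can map u to a timestep,
        # e.g. t = arctan(exp(u)) in TrigFlow-style parameterizations.
        return torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))

    u = sample_logit_normal(4, logit_mean=0.0, logit_std=1.6)
    print(u.shape)  # torch.Size([4])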
@@ -872,7 +949,6 @@ def main(args):
    if args.seed is not None:
        set_seed(args.seed)

-
    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
@@ -904,8 +980,9 @@ def main(args):

    ori_transformer = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
-        guidance_embeds=True, cross_attention_type='vanilla'
+        guidance_embeds=True,
    )
+    ori_transformer.set_attn_processor(SanaVanillaAttnProcessor())

    ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
@@ -929,7 +1006,6 @@ def main(args):

    zero_state_dict = {}

-
    target_device = accelerator.device
    param_w1 = guidance_embedder_module.linear_1.weight
    zero_state_dict['linear_1.weight'] = torch.zeros(param_w1.shape, device=target_device)
@@ -941,7 +1017,6 @@ def main(args):
    zero_state_dict['linear_2.bias'] = torch.zeros(param_b2.shape, device=target_device)
    guidance_embedder_module.load_state_dict(zero_state_dict, strict=False, assign=True)

-
    transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
    pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()

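The zero_state_dict block above zero-initializes the guidance embedder's two linear layers. The effect, sketched here on a plain nn.Linear with a hypothetical width, is that the newly added guidance pathway outputs zeros, and so contributes nothing, at the start of training:

    import torch
    from torch import nn

    linear = nn.Linear(256, 256)  # stand-in for the guidance embedder's linear_1 / linear_2
    with torch.no_grad():
        linear.weight.zero_()
        linear.bias.zero_()
    out = linear(torch.randn(4, 256))
    print(bool(torch.all(out == 0)))  # True: a zeroed linear maps every input to 0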
@@ -951,7 +1026,6 @@ def main(args):
        head_block_ids=args.head_block_ids,
    ).train()

-
    transformer.requires_grad_(True)
    pretrained_model.requires_grad_(False)
    disc.model.requires_grad_(False)
@@ -1005,7 +1079,6 @@ def main(args):
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

-
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
@@ -1063,7 +1136,6 @@ def load_model_hook(models, input_dir):
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

-
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
@@ -1087,7 +1159,6 @@ def load_model_hook(models, input_dir):
    else:
        optimizer_class = torch.optim.AdamW

-
    # Optimization parameters
    optimizer_G = optimizer_class(
        transformer.parameters(),
@@ -1391,12 +1462,10 @@ def model_wrapper(scaled_x_t, t):
                z_D = torch.randn_like(model_input) * sigma_data
                noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D

-
                # Calculate adversarial loss
                pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
                adv_loss = -torch.mean(pred_fake)

-
                # Total loss = sCM loss + LADD loss

                total_loss = args.scm_lambda * loss + adv_loss * args.adv_lambda
@@ -1405,8 +1474,6 @@ def model_wrapper(scaled_x_t, t):

                accelerator.backward(total_loss)

-
-
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(transformer.parameters(), args.gradient_clip)
                    if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
@@ -1504,7 +1571,6 @@ def model_wrapper(scaled_x_t, t):

                accelerator.backward(loss_D)

-
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(disc.parameters(), args.gradient_clip)
                    if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
@@ -1519,7 +1585,6 @@ def model_wrapper(scaled_x_t, t):
                optimizer_D.step()
                optimizer_D.zero_grad(set_to_none=True)

-
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
@@ -1584,7 +1649,6 @@ def model_wrapper(scaled_x_t, t):
                        images = None
                        del pipeline

-
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)