@@ -32,7 +32,7 @@ def save_checkpoint_optimizer(model, optimizer, rank, output_dir, step, discrimi
     save_dir = os.path.join(output_dir, f"checkpoint-{step}")
     os.makedirs(save_dir, exist_ok=True)
     # save using safetensors
-    if rank <= 0 and not discriminator:
+    if rank == 0 and not discriminator:
         weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors")
         save_file(cpu_state, weight_path)
         config_dict = dict(model.config)
@@ -60,7 +60,7 @@ def save_checkpoint(transformer, rank, output_dir, step):
     ):
         cpu_state = transformer.state_dict()
     # todo move to get_state_dict
-    if rank <= 0:
+    if rank == 0:
         save_dir = os.path.join(output_dir, f"checkpoint-{step}")
         os.makedirs(save_dir, exist_ok=True)
         # save using safetensors
@@ -98,7 +98,7 @@ def save_checkpoint_generator_discriminator(
     hf_weight_dir = os.path.join(save_dir, "hf_weights")
     os.makedirs(hf_weight_dir, exist_ok=True)
     # save using safetensors
-    if rank <= 0:
+    if rank == 0:
         config_dict = dict(model.config)
         config_path = os.path.join(hf_weight_dir, "config.json")
         # save dict as json
@@ -139,7 +139,7 @@ def save_checkpoint_generator_discriminator(
         optim_state = FSDP.optim_state_dict(discriminator, discriminator_optimizer)
         model_state = discriminator.state_dict()
         state_dict = {"optimizer": optim_state, "model": model_state}
-        if rank <= 0:
+        if rank == 0:
             discriminator_fsdp_state_fil = os.path.join(discriminator_fsdp_state_dir, "discriminator_state.pt")
             torch.save(state_dict, discriminator_fsdp_state_fil)

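For context, the hunk above gathers both model and optimizer state before the rank-0 write. Below is a minimal sketch of that gather-then-save pattern using PyTorch's FSDP full-state-dict APIs; the function and file names are illustrative, not taken from this repo.

import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    FullStateDictConfig,
    FullOptimStateDictConfig,
)

def save_model_and_optimizer(model, optimizer, save_dir, rank):
    # Gather the full (unsharded) model and optimizer state; keep it on CPU and only on rank 0.
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
    # Every rank participates in the gather above, but only rank 0 writes the file.
    if rank == 0:
        os.makedirs(save_dir, exist_ok=True)
        torch.save(
            {"model": model_state, "optimizer": optim_state},
            os.path.join(save_dir, "state.pt"),  # hypothetical file name
        )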
@@ -178,7 +178,7 @@ def load_full_state_model(model, optimizer, checkpoint_file, rank):
178178 ):
179179 discriminator_state = torch .load (checkpoint_file )
180180 model_state = discriminator_state ["model" ]
181- if rank < = 0 :
181+ if rank = = 0 :
182182 optim_state = discriminator_state ["optimizer" ]
183183 else :
184184 optim_state = None
@@ -241,7 +241,7 @@ def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipelin
         optimizer,
     )

-    if rank <= 0:
+    if rank == 0:
         save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}")
         os.makedirs(save_dir, exist_ok=True)

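The remaining hunks all share the same rank-0 write path for Hugging Face-style weights. A minimal sketch of that path, assuming the safetensors API and a `cpu_state` dict already gathered to rank 0 by FSDP as in the hunks above (the helper name is made up; file names mirror the diff). Since torch.distributed ranks are non-negative integers, `rank == 0` selects the same process as the old `rank <= 0`, just more explicitly.

import json
import os
from safetensors.torch import save_file

def write_hf_weights(cpu_state, model_config, save_dir, rank):
    # cpu_state is assumed to be the full state dict gathered to rank 0 (rank0_only=True).
    if rank == 0:
        os.makedirs(save_dir, exist_ok=True)
        # Weights go into a single safetensors file, the config into a plain JSON file.
        save_file(cpu_state, os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(dict(model_config), f, indent=4)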