# See the License for the specific language governing permissions and

import argparse
-import copy
-import itertools
-import json
+import io
import logging
import math
import os
-import random
import shutil
-import warnings
from pathlib import Path
+from typing import Callable

import accelerate
-import io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as T
-import torchvision.transforms.functional as TF
import transformers
-import webdataset as wds
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
-from braceexpand import braceexpand
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
-from peft.utils import get_peft_model_state_dict
from PIL import Image
-from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file
from torch.nn.utils.spectral_norm import SpectralNorm
-from torch.utils.data import default_collate, Dataset, DataLoader
-from torchvision.transforms.functional import crop
+from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, Gemma2Model
-from typing import Callable, List, Union
-from webdataset.tariterators import (
-    base_plus_ext,
-    tar_file_expander,
-    url_opener,
-    valid_sample,
-)

import diffusers
from diffusers import (
    AutoencoderDC,
-    FlowMatchEulerDiscreteScheduler,
    SanaPipeline,
    SanaSprintPipeline,
    SanaTransformer2DModel,
-    SCMScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
-    cast_training_params,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
    free_memory,
)
from diffusers.utils import (
    check_min_version,
-    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False
-
+
COMPLEX_HUMAN_INSTRUCTION = [
    "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
    "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
    "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
    "User Prompt: ",
]
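# NOTE: the items above are joined and prepended to each user prompt as a system-style
# instruction for the Gemma-2 text encoder (assumption, based on how the SANA pipelines
# use `complex_human_instruction`).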
-
+


class Text2ImageDataset(Dataset):
@@ -140,17 +116,17 @@ def __init__(self, hf_dataset, resolution=1024):

    def __len__(self):
        return len(self.dataset)
-
+
    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = item['llava']
        image_bytes = item['image']
-
+
        # Convert bytes to PIL Image
        image = Image.open(io.BytesIO(image_bytes))
-
+
        image_tensor = self.transform(image)
-
+
        return {
            'text': text,
            'image': image_tensor
@@ -768,7 +744,7 @@ def state_dict(self):

    def __getattr__(self, name):
        return getattr(self.disc, name)
-
+
class SanaTrigFlow(SanaTransformer2DModel):
    def __init__(self, original_model, guidance=False):
        self.__dict__ = original_model.__dict__
@@ -779,7 +755,7 @@ def __init__(self, original_model, guidance=False):
        self.logvar_linear = torch.nn.Linear(hidden_size, 1)
        torch.nn.init.xavier_uniform_(self.logvar_linear.weight)
        torch.nn.init.constant_(self.logvar_linear.bias, 0)
-
+
    def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None, jvp=False, return_logvar=False, **kwargs):
        batch_size = hidden_states.shape[0]
        latents = hidden_states
@@ -812,8 +788,8 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
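        # Map the pretrained flow-matching prediction (`model_out`, made at the derived flow
        # timestep) onto the TrigFlow parameterization (cf. SANA-Sprint); the sqrt term
        # renormalizes by the flow-time noise level.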
        trigflow_model_out = ((1 - 2 * flow_timestep_expanded) * latent_model_input + (1 - 2 * flow_timestep_expanded + 2 * flow_timestep_expanded**2) * model_out) / torch.sqrt(
            flow_timestep_expanded**2 + (1 - flow_timestep_expanded)**2
        )
-
-
+
+
        if self.guidance and guidance is not None:
            timestep, embedded_timestep = self.time_embed(
                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
@@ -822,15 +798,15 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
            timestep, embedded_timestep = self.time_embed(
                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )
-
+
        if return_logvar:
            logvar = self.logvar_linear(embedded_timestep)
            return trigflow_model_out, logvar
-
+

        return (trigflow_model_out,)

-
+

def compute_density_for_timestep_sampling_scm(
    batch_size: int, logit_mean: float = None, logit_std: float = None
@@ -925,19 +901,19 @@ def main(args):
        revision=args.revision,
        variant=args.variant,
    )
-
+
    ori_transformer = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
        guidance_embeds=True, cross_attention_type='vanilla'
    )
-
+
    ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
        guidance_embeds=False
    )
-
+
    original_state_dict = load_file(f"{args.pretrained_model_name_or_path}/transformer/diffusion_pytorch_model.safetensors")
-
+
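    # The original checkpoint stores the timestep embedder under 'time_embed.emb.*', while the
    # guidance-embedded transformer expects 'time_embed.*'; remap those keys before loading.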
    param_mapping = {
        'time_embed.emb.timestep_embedder.linear_1.weight': 'time_embed.timestep_embedder.linear_1.weight',
        'time_embed.emb.timestep_embedder.linear_1.bias': 'time_embed.timestep_embedder.linear_1.bias',
@@ -968,7 +944,7 @@ def main(args):

    transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
    pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()
-
+
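    # `transformer` is the trainable student (with a CFG-scale guidance embedding);
    # `pretrained_model` stays frozen as the TrigFlow teacher and provides the feature
    # backbone for the discriminator below.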
    disc = SanaMSCMDiscriminator(
        pretrained_model,
        is_multiscale=args.ladd_multi_scale,
@@ -1134,7 +1110,7 @@ def load_model_hook(models, input_dir):
        data_files=args.file_path,
        split='train',
    )
-
+
    train_dataset = Text2ImageDataset(
        hf_dataset=hf_dataset,
        resolution=args.resolution,
@@ -1282,8 +1258,8 @@ def load_model_hook(models, input_dir):
                # Add noise according to TrigFlow.
                # zt = cos(t) * x + sin(t) * noise
                t = u.view(-1, 1, 1, 1)
-                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
-
+                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
+

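                # Sample a per-example CFG scale and pass it through the model's guidance
                # embedding (scaled by args.guidance_embeds_scale) rather than running a
                # separate unconditional forward pass.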
                scm_cfg_scale = torch.tensor(
                    np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),
@@ -1295,7 +1271,7 @@ def model_wrapper(scaled_x_t, t):
                        hidden_states=scaled_x_t, timestep=t.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale), jvp=True, return_logvar=True
                    )
                    return pred, logvar
-
+
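                # Alternate generator and discriminator phases: "G" updates the student with the
                # consistency + adversarial losses, "D" updates the discriminator on real/fake pairs.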
                if phase == "G":
                    transformer.train()
                    disc.eval()
@@ -1322,8 +1298,8 @@ def model_wrapper(scaled_x_t, t):

                    v_x = torch.cos(t) * torch.sin(t) * dxt_dt / sigma_data
                    v_t = torch.cos(t) * torch.sin(t)
-
-
+
+
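                    # (v_x, v_t) is the tangent along the TrigFlow ODE trajectory; the JVP below
                    # evaluates F_theta and its directional derivative in a single forward pass.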
                    # Adapted from https://github.com/xandergos/sCM-mnist/blob/master/train_consistency.py
                    with torch.no_grad():
                        F_theta, F_theta_grad, logvar = torch.func.jvp(
@@ -1371,8 +1347,8 @@ def model_wrapper(scaled_x_t, t):
                    loss_no_logvar = loss_no_logvar.mean()
                    loss_no_weight = l2_loss.mean()
                    g_norm = g_norm.mean()
-
-
+
+
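                    # TrigFlow consistency estimate of the clean latent:
                    # pred_x_0 = cos(t) * x_t - sin(t) * sigma_data * F_theta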
                    pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data

                    if args.train_largest_timestep:
@@ -1414,7 +1390,7 @@ def model_wrapper(scaled_x_t, t):
                    # Add noise to predicted x0
                    z_D = torch.randn_like(model_input) * sigma_data
                    noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D
-
+

                    # Calculate adversarial loss
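                    # LADD-style adversarial term: the re-noised prediction (scaled to unit
                    # variance by 1 / sigma_data) should score as "real" under the frozen
                    # discriminator.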
                    pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
@@ -1445,7 +1421,7 @@ def model_wrapper(scaled_x_t, t):
                    optimizer_G.step()
                    lr_scheduler.step()
                    optimizer_G.zero_grad(set_to_none=True)
-
+
                elif phase == "D":
                    transformer.eval()
                    disc.train()
@@ -1515,7 +1491,7 @@ def model_wrapper(scaled_x_t, t):


                    # Calculate D loss
-
+
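                    # Fake samples are re-noised student predictions; real samples are re-noised
                    # ground-truth latents, each at independently sampled timesteps.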
                    pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D_fake.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
                    pred_true = disc(hidden_states=(noised_latents / sigma_data), timestep=t_D_real.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)

@@ -1542,7 +1518,7 @@ def model_wrapper(scaled_x_t, t):

                    optimizer_D.step()
                    optimizer_D.zero_grad(set_to_none=True)
-
+

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
@@ -1616,14 +1592,14 @@ def model_wrapper(scaled_x_t, t):
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
-
+
        # Collect the trained discriminator head weights
        disc = unwrap_model(disc)
        disc_heads_state_dict = disc.heads.state_dict()
-
+
        # Save transformer model
        transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
-
+
        # Save discriminator heads
        torch.save(disc_heads_state_dict, os.path.join(args.output_dir, "disc_heads.pt"))
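        # Only the trained heads are persisted; the discriminator's feature backbone is the
        # frozen teacher transformer, so it is not saved separately.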

@@ -1677,4 +1653,4 @@ def model_wrapper(scaled_x_t, t):

if __name__ == "__main__":
    args = parse_args()
-    main(args)
+    main(args)