from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr

from diffusers import FluxPipeline


# Module-level structured logger used by the worker function and the entry point.
logger = structlog.get_logger()
# File that the index-0 worker writes the torch_xla metrics report to.
metrics_filepath = "/tmp/metrics_report.txt"


def _encode_prompt(text_pipe, prompt, device):
    """Encode *prompt* with *text_pipe* and move both embeddings to *device*."""
    with torch.no_grad():
        # The third return value (text ids) is not needed by the image pipeline call.
        prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
    return prompt_embeds.to(device), pooled_prompt_embeds.to(device)


def _generate(flux_pipe, prompt_embeds, pooled_prompt_embeds, *, steps, guidance, height, width):
    """Run *flux_pipe* once on pre-computed embeddings and return the first image."""
    return flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=steps,
        guidance_scale=guidance,
        height=height,
        width=width,
    ).images[0]


def _main(index, args, text_pipe, ckpt_id):
    """Per-process worker: warm up the Flux pipeline on XLA, then time repeated inference.

    Args:
        index: Process index; not passed explicitly by the caller, so it is supplied
            by ``xmp.spawn``. Used to derive a per-process seed and to restrict
            per-iteration logging and the metrics report to process 0.
        args: Parsed command-line namespace. Reads ``profile``, ``profile_duration``,
            ``width``, ``height``, ``guidance``, ``schnell``, ``seed`` and ``itters``.
        text_pipe: Pipeline object whose ``encode_prompt`` produces the prompt embeddings.
        ckpt_id: Checkpoint identifier passed to ``FluxPipeline.from_pretrained``.

    Side effects:
        Creates the compiler-cache and profiler directories under ``/tmp/data``, writes
        ``/tmp/compile_out.png`` and ``/tmp/inference_out-<index>.png``, and (process 0
        only) writes the metrics report to ``metrics_filepath``.
    """
    cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
    cache_path.mkdir(parents=True, exist_ok=True)
    xr.initialize_cache(str(cache_path), readonly=False)

    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        # Keep a reference to the returned server object for the lifetime of this function.
        _ = xp.start_server(profiler_port)
    device0 = xm.xla_device()

    logger.info(f"loading flux from {ckpt_id}")
    # Text encoders/tokenizers are omitted here; prompts are encoded by `text_pipe`.
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
    ).to(device0)
    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)

    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    logger.info("starting compilation run...")
    ts = perf_counter()
    prompt_embeds, pooled_prompt_embeds = _encode_prompt(text_pipe, prompt, device0)
    # NOTE(review): the warm-up always uses 28 steps, even when --schnell selects
    # n_steps == 4 for the timed runs -- confirm this is intentional.
    image = _generate(
        flux_pipe, prompt_embeds, pooled_prompt_embeds, steps=28, guidance=guidance, height=height, width=width
    )
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # Offset the seed per process so each worker starts from a different RNG state.
    base_seed = 4096 if args.seed is None else args.seed
    seed_range = 1000
    unique_seed = base_seed + index * seed_range
    xm.set_rng_state(seed=unique_seed, device=device0)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        # The timed region covers prompt encoding, host-to-device transfer and generation.
        ts = perf_counter()
        prompt_embeds, pooled_prompt_embeds = _encode_prompt(text_pipe, prompt, device0)

        if args.profile:
            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
        image = _generate(
            flux_pipe, prompt_embeds, pooled_prompt_embeds, steps=n_steps, guidance=guidance, height=height, width=width
        )
        inference_time = perf_counter() - ts
        if index == 0:
            logger.info(f"inference time: {inference_time}")
        times.append(inference_time)

    if times:
        logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
        # An explicit extension is required so the image writer can pick an output format.
        image.save(f"/tmp/inference_out-{index}.png")
    else:
        # With zero iterations there is nothing to average and no inference image to save.
        logger.warning("no inference iterations were run; skipping average and inference image")

    if index == 0:
        metrics_report = met.metrics_report()
        with open(metrics_filepath, "w", encoding="utf-8") as fout:
            fout.write(metrics_report)
        logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    # Command-line interface for the benchmark; option names are kept as-is for compatibility.
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="times to run inference and get avg time in sec.")
    args = parser.parse_args()
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"
    # Load only the text-encoding side (no transformer/VAE) on CPU; it is handed to
    # every spawned worker, which uses it to encode the prompt.
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    # `_main` receives the process index as its first argument, followed by `args`.
    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))