1- from time import perf_counter
2- from pathlib import Path
31from argparse import ArgumentParser
2+ from pathlib import Path
3+ from time import perf_counter
44
55import structlog
6-
76import torch
87import torch_xla .core .xla_model as xm
9- import torch_xla .runtime as xr
10- import torch_xla .debug .profiler as xp
118import torch_xla .debug .metrics as met
12- from diffusers import FluxPipeline
9+ import torch_xla . debug . profiler as xp
1310import torch_xla .distributed .xla_multiprocessing as xmp
11+ import torch_xla .runtime as xr
12+
13+ from diffusers import FluxPipeline
14+
1415
# Module-level structured logger used by `_main` for progress and timing messages.
logger = structlog.get_logger()

# Destination of the XLA metrics report that the index-0 worker writes at the
# end of `_main`. The literal previously began with a space, which turned it
# into a relative path under a directory named " "; it is an absolute path now.
metrics_filepath = "/tmp/metrics_report.txt"
1718
def _encode_prompt_on_device(text_pipe, prompt, device):
    """Encode ``prompt`` with ``text_pipe`` and move both embeddings to ``device``.

    ``encode_prompt`` returns a third value as well; nothing in this script
    reads it, so it is discarded instead of being bound to an unused name.
    """
    prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
    return prompt_embeds.to(device), pooled_prompt_embeds.to(device)


def _main(index, args, text_pipe, ckpt_id):
    """Run one warm-up (compilation) pass and ``args.itters`` timed passes of Flux.

    Args:
        index: Worker index supplied by ``xmp.spawn``. It offsets the RNG seed,
            names the output image, and restricts per-iteration timing logs and
            the metrics report to the worker with index 0.
        args: Parsed command-line namespace. Reads ``profile``,
            ``profile_duration``, ``width``, ``height``, ``guidance``,
            ``schnell``, ``seed`` and ``itters``.
        text_pipe: Pipeline whose ``encode_prompt`` produces the prompt
            embeddings that are then moved to the XLA device.
        ckpt_id: Checkpoint identifier passed to ``FluxPipeline.from_pretrained``.
    """
    # Compiler cache directory, opened read-write.
    cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
    cache_path.mkdir(parents=True, exist_ok=True)
    xr.initialize_cache(str(cache_path), readonly=False)

    # Fixed: a leading space in this literal made it a relative path.
    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f" starting profiler on port {profiler_port} ")
        # The return value is bound (and ignored) exactly as before.
        _ = xp.start_server(profiler_port)
    device0 = xm.xla_device()

    logger.info(f"loading flux from {ckpt_id} ")
    # The text encoders and tokenizers are not loaded here; prompt encoding is
    # done by `text_pipe`, which the caller provides.
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        torch_dtype=torch.bfloat16,
    ).to(device0)
    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)

    prompt = " photograph of an electronics chip in the shape of a race car with trillium written on its side"
    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    logger.info(" starting compilation run...")
    ts = perf_counter()
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_on_device(text_pipe, prompt, device0)
        # The warm-up pass always uses 28 steps, independent of `n_steps`
        # (unchanged from the previous revision).
        image = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=28,
            guidance_scale=guidance,
            height=height,
            width=width,
        ).images[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # Give every worker its own seed by spacing the seeds `seed_range` apart.
    base_seed = 4096 if args.seed is None else args.seed
    seed_range = 1000
    unique_seed = base_seed + index * seed_range
    xm.set_rng_state(seed=unique_seed, device=device0)

    times = []
    logger.info(" starting inference run...")
    for _ in range(args.itters):
        # Each timing covers prompt encoding, the host-to-device copy and the
        # pipeline call.
        ts = perf_counter()
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds = _encode_prompt_on_device(text_pipe, prompt, device0)

            if args.profile:
                # Fixed: the address literal previously carried a trailing space.
                xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
            image = flux_pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=n_steps,
                guidance_scale=guidance,
                height=height,
                width=width,
            ).images[0]
        inference_time = perf_counter() - ts
        if index == 0:
            logger.info(f"inference time: {inference_time} ")
        times.append(inference_time)

    if times:
        average_time = sum(times) / len(times)
        logger.info(f" avg. inference over {args.itters} iterations took {average_time} sec.")
        # Fixed: stray spaces in this path literal ("... -0 .png") are removed.
        image.save(f"/tmp/inference_out-{index}.png")
    else:
        # `--itters` of zero or less runs no timed pass; the average used to
        # fail with ZeroDivisionError in that case.
        logger.warning("no inference iterations were run; skipping average timing and image output")

    if index == 0:
        metrics_report = met.metrics_report()
        with open(metrics_filepath, "w+") as fout:
            fout.write(metrics_report)
        logger.info(f"saved metric information as {metrics_filepath} ")
102+
86103
def _build_arg_parser():
    """Build the command-line parser for the Flux inference script.

    Option strings and actions are written without the stray leading spaces
    that previously made argparse reject them. The ``--itters`` spelling is
    kept so existing invocations continue to work.
    """
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="times to run inference and get avg time in sec.")
    return parser


# Fixed: the guard compared against " __main__" (leading space), which never
# equals `__name__`, so the script body could not run.
if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"
    # The text-encoding half of the pipeline is loaded without the transformer
    # and VAE and placed on the CPU; it is handed to every spawned worker.
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    # `xmp.spawn` calls `_main(index, args, text_pipe, ckpt_id)` in each worker.
    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
0 commit comments