1313import torch_xla .debug .metrics as met
1414
1515from diffusers import FluxPipeline
16+ import torch_xla .distributed .xla_multiprocessing as xmp
1617
1718logger = structlog .get_logger ()
1819metrics_filepath = '/tmp/metrics_report.txt'
1920
20- if __name__ == '__main__' :
21- parser = ArgumentParser ()
22- parser .add_argument ('--schnell' , action = 'store_true' , help = 'run flux schnell instead of dev' )
23- parser .add_argument ('--width' , type = int , default = 1024 , help = 'width of the image to generate' )
24- parser .add_argument ('--height' , type = int , default = 1024 , help = 'height of the image to generate' )
25- parser .add_argument ('--guidance' , type = float , default = 3.5 , help = 'gauidance strentgh for dev' )
26- parser .add_argument ('--seed' , type = int , default = None , help = 'seed for inference' )
27- parser .add_argument ('--profile' , action = 'store_true' , help = 'enable profiling' )
28- parser .add_argument ('--profile-duration' , type = int , default = 10000 , help = 'duration for profiling in msec.' )
29- args = parser .parse_args ()
21+ def _main (index , args , text_pipe , ckpt_id ):
3022
31- cache_path = Path ('/tmp/data/compiler_cache ' )
23+ cache_path = Path ('/tmp/data/compiler_cache_tRiLlium_eXp ' )
3224 cache_path .mkdir (parents = True , exist_ok = True )
3325 xr .initialize_cache (str (cache_path ), readonly = False )
3426
35- profile_path = Path ('/tmp/data/profiler_out ' )
27+ profile_path = Path ('/tmp/data/profiler_out_tRiLlium_eXp ' )
3628 profile_path .mkdir (parents = True , exist_ok = True )
3729 profiler_port = 9012
3830 profile_duration = args .profile_duration
3931 if args .profile :
4032 logger .info (f'starting profiler on port { profiler_port } ' )
4133 _ = xp .start_server (profiler_port )
34+ device0 = xm .xla_device ()
4235
43- device0 = xm .xla_device (0 )
44- device1 = xm .xla_device (1 )
45- logger .info (f'text encoders: { device0 } , flux: { device1 } ' )
46-
47- if args .schnell :
48- ckpt_id = "black-forest-labs/FLUX.1-schnell"
49- else :
50- ckpt_id = "black-forest-labs/FLUX.1-dev"
5136 logger .info (f'loading flux from { ckpt_id } ' )
52-
53- text_pipe = FluxPipeline .from_pretrained (ckpt_id , transformer = None , vae = None , torch_dtype = torch .bfloat16 ).to (device0 )
5437 flux_pipe = FluxPipeline .from_pretrained (ckpt_id , text_encoder = None , tokenizer = None ,
55- text_encoder_2 = None , tokenizer_2 = None , torch_dtype = torch .bfloat16 ).to (device1 )
38+ text_encoder_2 = None , tokenizer_2 = None , torch_dtype = torch .bfloat16 ).to (device0 )
5639
5740 prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side'
5841 width = args .width
6548 with torch .no_grad ():
6649 prompt_embeds , pooled_prompt_embeds , text_ids = text_pipe .encode_prompt (
6750 prompt = prompt , prompt_2 = None , max_sequence_length = 512 )
68- prompt_embeds = prompt_embeds .to (device1 )
69- pooled_prompt_embeds = pooled_prompt_embeds .to (device1 )
51+ prompt_embeds = prompt_embeds .to (device0 )
52+ pooled_prompt_embeds = pooled_prompt_embeds .to (device0 )
7053
7154 image = flux_pipe (prompt_embeds = prompt_embeds , pooled_prompt_embeds = pooled_prompt_embeds ,
7255 num_inference_steps = 28 , guidance_scale = guidance , height = height , width = width ).images [0 ]
7356 logger .info (f'compilation took { perf_counter () - ts } sec.' )
7457 image .save ('/tmp/compile_out.png' )
7558
76- seed = 0 if args .seed is None else args .seed
77- xm .set_rng_state (seed = seed , device = device0 )
78- xm .set_rng_state (seed = seed , device = device1 )
79-
59+ base_seed = 4096 if args .seed is None else args .seed
60+ seed_range = 1000
61+ unique_seed = base_seed + index * seed_range
62+ xm .set_rng_state (seed = unique_seed , device = device0 )
63+ times = []
8064 logger .info ('starting inference run...' )
81- ts = perf_counter ()
82- with torch .no_grad ():
83- prompt_embeds , pooled_prompt_embeds , text_ids = text_pipe .encode_prompt (
84- prompt = prompt , prompt_2 = None , max_sequence_length = 512 )
85- prompt_embeds = prompt_embeds .to (device1 )
86- pooled_prompt_embeds = pooled_prompt_embeds .to (device1 )
87- xm .wait_device_ops ()
88-
89- if args .profile :
90- xp .trace_detached (f"localhost:{ profiler_port } " , str (profile_path ), duration_ms = profile_duration )
91- image = flux_pipe (prompt_embeds = prompt_embeds , pooled_prompt_embeds = pooled_prompt_embeds ,
92- num_inference_steps = n_steps , guidance_scale = guidance , height = height , width = width ).images [0 ]
93- logger .info (f'inference took { perf_counter () - ts } sec.' )
94- image .save ('/tmp/inference_out.png' )
95- metrics_report = met .metrics_report ()
96- with open (metrics_filepath , 'w+' ) as fout :
97- fout .write (metrics_report )
98- logger .info (f'saved metric information as { metrics_filepath } ' )
65+ for _ in range (args .itters ):
66+ ts = perf_counter ()
67+ with torch .no_grad ():
68+ prompt_embeds , pooled_prompt_embeds , text_ids = text_pipe .encode_prompt (
69+ prompt = prompt , prompt_2 = None , max_sequence_length = 512 )
70+ prompt_embeds = prompt_embeds .to (device0 )
71+ pooled_prompt_embeds = pooled_prompt_embeds .to (device0 )
72+
73+ if args .profile :
74+ xp .trace_detached (f"localhost:{ profiler_port } " , str (profile_path ), duration_ms = profile_duration )
75+ image = flux_pipe (prompt_embeds = prompt_embeds , pooled_prompt_embeds = pooled_prompt_embeds ,
76+ num_inference_steps = n_steps , guidance_scale = guidance , height = height , width = width ).images [0 ]
77+ inference_time = perf_counter () - ts
78+ if index == 0 :
79+ logger .info (f"inference time: { inference_time } " )
80+ times .append (inference_time )
81+ logger .info (f'avg. inference over { args .itters } iterations took { sum (times )/ len (times )} sec.' )
82+ image .save (f'/home/tmp/inference_out-{ index } .png' )
83+ if index == 0 :
84+ metrics_report = met .metrics_report ()
85+ with open (metrics_filepath , 'w+' ) as fout :
86+ fout .write (metrics_report )
87+ logger .info (f'saved metric information as { metrics_filepath } ' )
9988
89+ if __name__ == '__main__' :
90+ parser = ArgumentParser ()
91+ parser .add_argument ('--schnell' , action = 'store_true' , help = 'run flux schnell instead of dev' )
92+ parser .add_argument ('--width' , type = int , default = 1024 , help = 'width of the image to generate' )
93+ parser .add_argument ('--height' , type = int , default = 1024 , help = 'height of the image to generate' )
94+ parser .add_argument ('--guidance' , type = float , default = 3.5 , help = 'gauidance strentgh for dev' )
95+ parser .add_argument ('--seed' , type = int , default = None , help = 'seed for inference' )
96+ parser .add_argument ('--profile' , action = 'store_true' , help = 'enable profiling' )
97+ parser .add_argument ('--profile-duration' , type = int , default = 10000 , help = 'duration for profiling in msec.' )
98+ parser .add_argument ('--itters' , type = int , default = 15 , help = 'tiems to run inference and get avg time in sec.' )
99+ args = parser .parse_args ()
100+ if args .schnell :
101+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
102+ else :
103+ ckpt_id = "black-forest-labs/FLUX.1-dev"
104+ text_pipe = FluxPipeline .from_pretrained (ckpt_id , transformer = None , vae = None , torch_dtype = torch .bfloat16 ).to ('cpu' )
105+ xmp .spawn (_main , args = (args , text_pipe , ckpt_id ))
0 commit comments