
Commit 358908d

add parallel generations to flux ptxla code.
1 parent 9ae0379 commit 358908d

2 files changed (+93, -87 lines)
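The commit moves the per-device work into a worker function that is fanned out with torch_xla.distributed.xla_multiprocessing (xmp), so each process generates its own image with its own seed. As a rough, minimal sketch of that pattern (illustrative only; _worker and the seed spacing of 1000 mirror the diff below and are not library API):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(index, base_seed):
    # xmp.spawn runs this function once per local XLA device and passes the
    # process index as the first positional argument.
    device = xm.xla_device()
    xm.set_rng_state(seed=base_seed + index * 1000, device=device)
    # ... load the pipeline and run generation for this process here ...

if __name__ == '__main__':
    # one process per visible TPU/XLA device; extra args follow the index
    xmp.spawn(_worker, args=(4096,))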

examples/research_projects/pytorch_xla/inference/flux/generate_flux.py

Lines changed: 53 additions & 47 deletions
@@ -13,46 +13,29 @@
import torch_xla.debug.metrics as met

from diffusers import FluxPipeline
+import torch_xla.distributed.xla_multiprocessing as xmp

logger = structlog.get_logger()
metrics_filepath = '/tmp/metrics_report.txt'

-if __name__ == '__main__':
-    parser = ArgumentParser()
-    parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev')
-    parser.add_argument('--width', type=int, default=1024, help='width of the image to generate')
-    parser.add_argument('--height', type=int, default=1024, help='height of the image to generate')
-    parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev')
-    parser.add_argument('--seed', type=int, default=None, help='seed for inference')
-    parser.add_argument('--profile', action='store_true', help='enable profiling')
-    parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.')
-    args = parser.parse_args()
+def _main(index, args, text_pipe, ckpt_id):

-    cache_path = Path('/tmp/data/compiler_cache')
+    cache_path = Path('/tmp/data/compiler_cache_tRiLlium_eXp')
    cache_path.mkdir(parents=True, exist_ok=True)
    xr.initialize_cache(str(cache_path), readonly=False)

-    profile_path = Path('/tmp/data/profiler_out')
+    profile_path = Path('/tmp/data/profiler_out_tRiLlium_eXp')
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f'starting profiler on port {profiler_port}')
        _ = xp.start_server(profiler_port)
+    device0 = xm.xla_device()

-    device0 = xm.xla_device(0)
-    device1 = xm.xla_device(1)
-    logger.info(f'text encoders: {device0}, flux: {device1}')
-
-    if args.schnell:
-        ckpt_id = "black-forest-labs/FLUX.1-schnell"
-    else:
-        ckpt_id = "black-forest-labs/FLUX.1-dev"
    logger.info(f'loading flux from {ckpt_id}')
-
-    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to(device0)
    flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None,
-                                             text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device1)
+                                             text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0)

    prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side'
    width = args.width
@@ -65,35 +48,58 @@
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512)
-        prompt_embeds = prompt_embeds.to(device1)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(device1)
+        prompt_embeds = prompt_embeds.to(device0)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

    image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                      num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0]
    logger.info(f'compilation took {perf_counter() - ts} sec.')
    image.save('/tmp/compile_out.png')

-    seed = 0 if args.seed is None else args.seed
-    xm.set_rng_state(seed=seed, device=device0)
-    xm.set_rng_state(seed=seed, device=device1)
-
+    base_seed = 4096 if args.seed is None else args.seed
+    seed_range = 1000
+    unique_seed = base_seed + index * seed_range
+    xm.set_rng_state(seed=unique_seed, device=device0)
+    times = []
    logger.info('starting inference run...')
-    ts = perf_counter()
-    with torch.no_grad():
-        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
-            prompt=prompt, prompt_2=None, max_sequence_length=512)
-        prompt_embeds = prompt_embeds.to(device1)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(device1)
-    xm.wait_device_ops()
-
-    if args.profile:
-        xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
-    image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-                      num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0]
-    logger.info(f'inference took {perf_counter() - ts} sec.')
-    image.save('/tmp/inference_out.png')
-    metrics_report = met.metrics_report()
-    with open(metrics_filepath, 'w+') as fout:
-        fout.write(metrics_report)
-    logger.info(f'saved metric information as {metrics_filepath}')
+    for _ in range(args.itters):
+        ts = perf_counter()
+        with torch.no_grad():
+            prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
+                prompt=prompt, prompt_2=None, max_sequence_length=512)
+            prompt_embeds = prompt_embeds.to(device0)
+            pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
+
+        if args.profile:
+            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
+        image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+                          num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0]
+        inference_time = perf_counter() - ts
+        if index == 0:
+            logger.info(f"inference time: {inference_time}")
+        times.append(inference_time)
+    logger.info(f'avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.')
+    image.save(f'/home/tmp/inference_out-{index}.png')
+    if index == 0:
+        metrics_report = met.metrics_report()
+        with open(metrics_filepath, 'w+') as fout:
+            fout.write(metrics_report)
+        logger.info(f'saved metric information as {metrics_filepath}')

+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev')
+    parser.add_argument('--width', type=int, default=1024, help='width of the image to generate')
+    parser.add_argument('--height', type=int, default=1024, help='height of the image to generate')
+    parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev')
+    parser.add_argument('--seed', type=int, default=None, help='seed for inference')
+    parser.add_argument('--profile', action='store_true', help='enable profiling')
+    parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.')
+    parser.add_argument('--itters', type=int, default=15, help='tiems to run inference and get avg time in sec.')
+    args = parser.parse_args()
+    if args.schnell:
+        ckpt_id = "black-forest-labs/FLUX.1-schnell"
+    else:
+        ckpt_id = "black-forest-labs/FLUX.1-dev"
+    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to('cpu')
+    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
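For reference, one plausible way to invoke the updated script on a TPU host; the flag names come from the argparse definitions above, while the exact command used by the authors is not shown in the commit:

python generate_flux.py --width 1024 --height 1024 --guidance 3.5 --itters 15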

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 40 additions & 40 deletions
@@ -708,48 +708,48 @@ def __call__(
            guidance = None

        # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+        #with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+            noise_pred = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents_dtype = latents.dtype
+
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+            if latents.dtype != latents_dtype:
+                if torch.backends.mps.is_available():
+                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                    latents = latents.to(latents_dtype)
+
+            if callback_on_step_end is not None:
+                callback_kwargs = {}
+                for k in callback_on_step_end_tensor_inputs:
+                    callback_kwargs[k] = locals()[k]
+                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                latents = callback_outputs.pop("latents", latents)
+                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                #     progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()
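In the modified denoising loop the pipeline's progress-bar wrapper is commented out, while the per-step xm.mark_step() call in the unchanged XLA_AVAILABLE block above is kept, so each scheduler step is still cut into its own lazy XLA graph. A minimal sketch of that per-step pattern (names other than xm.mark_step are hypothetical):

import torch_xla.core.xla_model as xm

for i, t in enumerate(timesteps):
    latents = denoise_step(latents, t)  # hypothetical per-step update
    xm.mark_step()                      # materialize this step's lazy graph on the device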
