|
| 1 | +from time import perf_counter |
| 2 | +from pathlib import Path |
| 3 | +from argparse import ArgumentParser |
| 4 | + |
| 5 | +import structlog |
| 6 | +import time |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch_xla |
| 10 | +import torch_xla.core.xla_model as xm |
| 11 | +import torch_xla.runtime as xr |
| 12 | +import torch_xla.debug.profiler as xp |
| 13 | +import torch_xla.debug.metrics as met |
| 14 | + |
| 15 | +from diffusers import FluxPipeline |
| 16 | + |
| 17 | +logger = structlog.get_logger() |
| 18 | +metrics_filepath = '/tmp/metrics_report.txt' |
| 19 | + |
| 20 | +if __name__ == '__main__': |
| 21 | + parser = ArgumentParser() |
| 22 | + parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev') |
| 23 | + parser.add_argument('--width', type=int, default=1024, help='width of the image to generate') |
| 24 | + parser.add_argument('--height', type=int, default=1024, help='height of the image to generate') |
| 25 | + parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev') |
| 26 | + parser.add_argument('--seed', type=int, default=None, help='seed for inference') |
| 27 | + parser.add_argument('--profile', action='store_true', help='enable profiling') |
| 28 | + parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.') |
| 29 | + args = parser.parse_args() |
| 30 | + |
| 31 | + cache_path = Path('/tmp/data/compiler_cache') |
| 32 | + cache_path.mkdir(parents=True, exist_ok=True) |
| 33 | + xr.initialize_cache(str(cache_path), readonly=False) |
| 34 | + |
| 35 | + profile_path = Path('/tmp/data/profiler_out') |
| 36 | + profile_path.mkdir(parents=True, exist_ok=True) |
| 37 | + profiler_port = 9012 |
| 38 | + profile_duration = args.profile_duration |
| 39 | + if args.profile: |
| 40 | + logger.info(f'starting profiler on port {profiler_port}') |
| 41 | + _ = xp.start_server(profiler_port) |
| 42 | + |
| 43 | + device0 = xm.xla_device(0) |
| 44 | + device1 = xm.xla_device(1) |
| 45 | + logger.info(f'text encoders: {device0}, flux: {device1}') |
| 46 | + |
| 47 | + if args.schnell: |
| 48 | + ckpt_id = "black-forest-labs/FLUX.1-schnell" |
| 49 | + else: |
| 50 | + ckpt_id = "black-forest-labs/FLUX.1-dev" |
| 51 | + logger.info(f'loading flux from {ckpt_id}') |
| 52 | + |
| 53 | + text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to(device0) |
| 54 | + flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None, |
| 55 | + text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device1) |
| 56 | + |
| 57 | + prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side' |
| 58 | + width = args.width |
| 59 | + height = args.height |
| 60 | + guidance = args.guidance |
| 61 | + n_steps = 4 if args.schnell else 28 |
| 62 | + |
| 63 | + logger.info('starting compilation run...') |
| 64 | + ts = perf_counter() |
| 65 | + with torch.no_grad(): |
| 66 | + prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( |
| 67 | + prompt=prompt, prompt_2=None, max_sequence_length=512) |
| 68 | + prompt_embeds = prompt_embeds.to(device1) |
| 69 | + pooled_prompt_embeds = pooled_prompt_embeds.to(device1) |
| 70 | + |
| 71 | + image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, |
| 72 | + num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0] |
| 73 | + logger.info(f'compilation took {perf_counter() - ts} sec.') |
| 74 | + image.save('/tmp/compile_out.png') |
| 75 | + |
| 76 | + seed = 0 if args.seed is None else args.seed |
| 77 | + xm.set_rng_state(seed=seed, device=device0) |
| 78 | + xm.set_rng_state(seed=seed, device=device1) |
| 79 | + |
| 80 | + logger.info('starting inference run...') |
| 81 | + ts = perf_counter() |
| 82 | + with torch.no_grad(): |
| 83 | + prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt( |
| 84 | + prompt=prompt, prompt_2=None, max_sequence_length=512) |
| 85 | + prompt_embeds = prompt_embeds.to(device1) |
| 86 | + pooled_prompt_embeds = pooled_prompt_embeds.to(device1) |
| 87 | + xm.wait_device_ops() |
| 88 | + |
| 89 | + if args.profile: |
| 90 | + xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration) |
| 91 | + image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, |
| 92 | + num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0] |
| 93 | + logger.info(f'inference took {perf_counter() - ts} sec.') |
| 94 | + image.save('/tmp/inference_out.png') |
| 95 | + metrics_report = met.metrics_report() |
| 96 | + with open(metrics_filepath, 'w+') as fout: |
| 97 | + fout.write(metrics_report) |
| 98 | + logger.info(f'saved metric information as {metrics_filepath}') |
| 99 | + |
0 commit comments