Skip to content

Commit 9ae0379

Browse files
upload generate file.
1 parent 9627cc1 commit 9ae0379

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from time import perf_counter
2+
from pathlib import Path
3+
from argparse import ArgumentParser
4+
5+
import structlog
6+
import time
7+
8+
import torch
9+
import torch_xla
10+
import torch_xla.core.xla_model as xm
11+
import torch_xla.runtime as xr
12+
import torch_xla.debug.profiler as xp
13+
import torch_xla.debug.metrics as met
14+
15+
from diffusers import FluxPipeline
16+
17+
logger = structlog.get_logger()
18+
metrics_filepath = '/tmp/metrics_report.txt'
19+
20+
if __name__ == '__main__':
21+
parser = ArgumentParser()
22+
parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev')
23+
parser.add_argument('--width', type=int, default=1024, help='width of the image to generate')
24+
parser.add_argument('--height', type=int, default=1024, help='height of the image to generate')
25+
parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev')
26+
parser.add_argument('--seed', type=int, default=None, help='seed for inference')
27+
parser.add_argument('--profile', action='store_true', help='enable profiling')
28+
parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.')
29+
args = parser.parse_args()
30+
31+
cache_path = Path('/tmp/data/compiler_cache')
32+
cache_path.mkdir(parents=True, exist_ok=True)
33+
xr.initialize_cache(str(cache_path), readonly=False)
34+
35+
profile_path = Path('/tmp/data/profiler_out')
36+
profile_path.mkdir(parents=True, exist_ok=True)
37+
profiler_port = 9012
38+
profile_duration = args.profile_duration
39+
if args.profile:
40+
logger.info(f'starting profiler on port {profiler_port}')
41+
_ = xp.start_server(profiler_port)
42+
43+
device0 = xm.xla_device(0)
44+
device1 = xm.xla_device(1)
45+
logger.info(f'text encoders: {device0}, flux: {device1}')
46+
47+
if args.schnell:
48+
ckpt_id = "black-forest-labs/FLUX.1-schnell"
49+
else:
50+
ckpt_id = "black-forest-labs/FLUX.1-dev"
51+
logger.info(f'loading flux from {ckpt_id}')
52+
53+
text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to(device0)
54+
flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None,
55+
text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device1)
56+
57+
prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side'
58+
width = args.width
59+
height = args.height
60+
guidance = args.guidance
61+
n_steps = 4 if args.schnell else 28
62+
63+
logger.info('starting compilation run...')
64+
ts = perf_counter()
65+
with torch.no_grad():
66+
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
67+
prompt=prompt, prompt_2=None, max_sequence_length=512)
68+
prompt_embeds = prompt_embeds.to(device1)
69+
pooled_prompt_embeds = pooled_prompt_embeds.to(device1)
70+
71+
image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
72+
num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0]
73+
logger.info(f'compilation took {perf_counter() - ts} sec.')
74+
image.save('/tmp/compile_out.png')
75+
76+
seed = 0 if args.seed is None else args.seed
77+
xm.set_rng_state(seed=seed, device=device0)
78+
xm.set_rng_state(seed=seed, device=device1)
79+
80+
logger.info('starting inference run...')
81+
ts = perf_counter()
82+
with torch.no_grad():
83+
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
84+
prompt=prompt, prompt_2=None, max_sequence_length=512)
85+
prompt_embeds = prompt_embeds.to(device1)
86+
pooled_prompt_embeds = pooled_prompt_embeds.to(device1)
87+
xm.wait_device_ops()
88+
89+
if args.profile:
90+
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
91+
image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
92+
num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0]
93+
logger.info(f'inference took {perf_counter() - ts} sec.')
94+
image.save('/tmp/inference_out.png')
95+
metrics_report = met.metrics_report()
96+
with open(metrics_filepath, 'w+') as fout:
97+
fout.write(metrics_report)
98+
logger.info(f'saved metric information as {metrics_filepath}')
99+

0 commit comments

Comments
 (0)