Commit ad1e1cf: run make style/quality
Parent: 378f7ed

2 files changed: 62 additions, 45 deletions
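
Both diffs below come from the repository's automated make style / make quality pass: quote normalization, argument wrapping, import sorting, and trailing-whitespace cleanup. For the purely cosmetic part of such a pass, a quick local sanity check is to compare ASTs before and after, since quote style, line wrapping, and trailing commas all parse to the same tree (import re-ordering, which this commit also does, moves nodes, so exclude reordered spans from this check). A minimal sketch; old.py and new.py are placeholder paths, not files from this commit:

import ast

# Quote style, line wrapping, trailing commas, and whitespace do not
# change the parsed AST, so a cosmetic-only diff yields identical dumps.
with open("old.py") as f_old, open("new.py") as f_new:
    old_tree = ast.dump(ast.parse(f_old.read()))
    new_tree = ast.dump(ast.parse(f_new.read()))
print("formatting-only:", old_tree == new_tree)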
Lines changed: 57 additions & 40 deletions
@@ -1,103 +1,120 @@
-from time import perf_counter
-from pathlib import Path
 from argparse import ArgumentParser
+from pathlib import Path
+from time import perf_counter
 
 import structlog
-
 import torch
 import torch_xla.core.xla_model as xm
-import torch_xla.runtime as xr
-import torch_xla.debug.profiler as xp
 import torch_xla.debug.metrics as met
-from diffusers import FluxPipeline
+import torch_xla.debug.profiler as xp
 import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr
+
+from diffusers import FluxPipeline
+
 
 logger = structlog.get_logger()
-metrics_filepath = '/tmp/metrics_report.txt'
+metrics_filepath = "/tmp/metrics_report.txt"
 
-def _main(index, args, text_pipe, ckpt_id):
 
-    cache_path = Path('/tmp/data/compiler_cache_tRiLlium_eXp')
+def _main(index, args, text_pipe, ckpt_id):
+    cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
     cache_path.mkdir(parents=True, exist_ok=True)
     xr.initialize_cache(str(cache_path), readonly=False)
 
-    profile_path = Path('/tmp/data/profiler_out_tRiLlium_eXp')
+    profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
     profile_path.mkdir(parents=True, exist_ok=True)
     profiler_port = 9012
     profile_duration = args.profile_duration
     if args.profile:
-        logger.info(f'starting profiler on port {profiler_port}')
+        logger.info(f"starting profiler on port {profiler_port}")
         _ = xp.start_server(profiler_port)
     device0 = xm.xla_device()
 
-    logger.info(f'loading flux from {ckpt_id}')
-    flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None,
-        text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0)
+    logger.info(f"loading flux from {ckpt_id}")
+    flux_pipe = FluxPipeline.from_pretrained(
+        ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
+    ).to(device0)
     flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
 
-    prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side'
+    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
     width = args.width
     height = args.height
     guidance = args.guidance
     n_steps = 4 if args.schnell else 28
 
-    logger.info('starting compilation run...')
+    logger.info("starting compilation run...")
     ts = perf_counter()
     with torch.no_grad():
         prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
-            prompt=prompt, prompt_2=None, max_sequence_length=512)
+            prompt=prompt, prompt_2=None, max_sequence_length=512
+        )
         prompt_embeds = prompt_embeds.to(device0)
         pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
 
-        image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-            num_inference_steps=28, guidance_scale=guidance, height=height, width=width).images[0]
-    logger.info(f'compilation took {perf_counter() - ts} sec.')
-    image.save('/tmp/compile_out.png')
+        image = flux_pipe(
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            num_inference_steps=28,
+            guidance_scale=guidance,
+            height=height,
+            width=width,
+        ).images[0]
+    logger.info(f"compilation took {perf_counter() - ts} sec.")
+    image.save("/tmp/compile_out.png")
 
     base_seed = 4096 if args.seed is None else args.seed
     seed_range = 1000
     unique_seed = base_seed + index * seed_range
     xm.set_rng_state(seed=unique_seed, device=device0)
     times = []
-    logger.info('starting inference run...')
+    logger.info("starting inference run...")
     for _ in range(args.itters):
         ts = perf_counter()
         with torch.no_grad():
             prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
-                prompt=prompt, prompt_2=None, max_sequence_length=512)
+                prompt=prompt, prompt_2=None, max_sequence_length=512
+            )
             prompt_embeds = prompt_embeds.to(device0)
             pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
 
         if args.profile:
             xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
-        image = flux_pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-            num_inference_steps=n_steps, guidance_scale=guidance, height=height, width=width).images[0]
+        image = flux_pipe(
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            num_inference_steps=n_steps,
+            guidance_scale=guidance,
+            height=height,
+            width=width,
+        ).images[0]
         inference_time = perf_counter() - ts
         if index == 0:
             logger.info(f"inference time: {inference_time}")
         times.append(inference_time)
-    logger.info(f'avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.')
-    image.save(f'/tmp/inference_out-{index}.png')
+    logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
+    image.save(f"/tmp/inference_out-{index}.png")
     if index == 0:
         metrics_report = met.metrics_report()
-        with open(metrics_filepath, 'w+') as fout:
+        with open(metrics_filepath, "w+") as fout:
             fout.write(metrics_report)
-        logger.info(f'saved metric information as {metrics_filepath}')
+        logger.info(f"saved metric information as {metrics_filepath}")
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument('--schnell', action='store_true', help='run flux schnell instead of dev')
-    parser.add_argument('--width', type=int, default=1024, help='width of the image to generate')
-    parser.add_argument('--height', type=int, default=1024, help='height of the image to generate')
-    parser.add_argument('--guidance', type=float, default=3.5, help='gauidance strentgh for dev')
-    parser.add_argument('--seed', type=int, default=None, help='seed for inference')
-    parser.add_argument('--profile', action='store_true', help='enable profiling')
-    parser.add_argument('--profile-duration', type=int, default=10000, help='duration for profiling in msec.')
-    parser.add_argument('--itters', type=int, default=15, help='tiems to run inference and get avg time in sec.')
+    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
+    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
+    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
+    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
+    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
+    parser.add_argument("--profile", action="store_true", help="enable profiling")
+    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
+    parser.add_argument("--itters", type=int, default=15, help="times to run inference and get avg time in sec.")
     args = parser.parse_args()
     if args.schnell:
         ckpt_id = "black-forest-labs/FLUX.1-schnell"
     else:
         ckpt_id = "black-forest-labs/FLUX.1-dev"
-    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to('cpu')
-    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
+    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
+    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
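
A note on the script itself: the first pass through flux_pipe is a warm-up that triggers XLA compilation (and populates the on-disk compiler cache), so only the later iterations reflect steady-state latency. The launch pattern keeps the text encoders on CPU in the parent process and uses xmp.spawn to fork one worker per XLA device, offsetting the RNG seed by the worker index so each device generates a distinct image. A minimal, self-contained sketch of that pattern (illustrative names; assumes a host with torch_xla installed):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index, base_seed):
    # Each spawned process owns one XLA device and receives its own index.
    device = xm.xla_device()
    # Offset the RNG per worker so every device samples differently.
    xm.set_rng_state(seed=base_seed + index * 1000, device=device)
    sample = torch.randn(2, 2, device=device)
    print(f"worker {index}: {sample.cpu()}")


if __name__ == "__main__":
    # Spawns one process per available XLA device.
    xmp.spawn(_worker, args=(4096,))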

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 5 deletions
@@ -2322,7 +2322,7 @@ def __call__(
         key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2522,9 +2522,9 @@ def __call__(
 
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
-
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -3503,7 +3503,7 @@ def __call__(
 
         query /= math.sqrt(head_dim)
         hidden_states = flash_attention(query, key, value, causal=False)
-
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -3523,7 +3523,7 @@ def __call__(
             return hidden_states, encoder_hidden_states
         else:
            return hidden_states
-
+
 
 class MochiVaeAttnProcessor2_0:
     r"""
