
Commit 19231e0: fixes

1 parent b1fc6e1
2 files changed: +9 -4 lines


run_benchmark.py

Lines changed: 7 additions & 2 deletions
@@ -4,6 +4,8 @@
 from torch.profiler import profile, record_function, ProfilerActivity
 from utils.benchmark_utils import annotate, create_parser
 from utils.pipeline_utils import load_pipeline # noqa: E402
+from diffusers.utils import load_image
+import os

 def _determine_pipe_call_kwargs(args):
     kwargs = {"max_sequence_length": 256, "guidance_scale": 0.0}
@@ -12,6 +14,7 @@ def _determine_pipe_call_kwargs(args):
         kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
     elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev":
         kwargs = {"max_sequence_length": 512, "guidance_scale": 2.5}
+        kwargs.update({"image": load_image(args.image)})
     return kwargs

 def set_rand_seeds(seed):
@@ -22,12 +25,14 @@ def set_rand_seeds(seed):
 def main(args):
     set_rand_seeds(args.seed)
     pipeline = load_pipeline(args)
+    if args.ckpt == "black-forest-labs/FLUX.1-Kontext-dev":
+        assert os.path.exists(args.image)
     set_rand_seeds(args.seed)

     # warmup
     for _ in range(3):
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
             generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
@@ -38,7 +43,7 @@ def main(args):
     for _ in range(10):
         begin = time.time()
         image = pipeline(
-            args.prompt,
+            prompt=args.prompt,
             num_inference_steps=args.num_inference_steps,
             generator=torch.manual_seed(args.seed),
             **_determine_pipe_call_kwargs(args)
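
For the Kontext checkpoint, the benchmark now has to feed an input image into the pipeline call, which is what the load_image(args.image) kwarg and the os.path.exists guard above provide. Below is a minimal sketch of the resulting call, assuming an --image argument is wired up in create_parser (that change is not part of this diff); the prompt and "input.png" path are placeholders.

# Sketch only: mirrors the kwargs assembled by _determine_pipe_call_kwargs for
# FLUX.1-Kontext-dev. The prompt and "input.png" path are placeholders, and the
# --image CLI flag is assumed to be added in create_parser (not shown here).
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

result = pipe(
    prompt="make the sky a sunset",     # placeholder prompt
    image=load_image("input.png"),      # extra kwarg the Kontext branch adds
    num_inference_steps=28,             # the benchmark uses args.num_inference_steps
    guidance_scale=2.5,
    max_sequence_length=512,
    generator=torch.manual_seed(42),
).images[0]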

utils/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import os
 import pathlib
 import torch
-from diffusers import FluxPipeline, FluxKontextPipeline
+from diffusers import DiffusionPipeline
 from torch._inductor.package import load_package as inductor_load_package
 from typing import List, Optional
 from PIL import Image
@@ -410,7 +410,7 @@ def optimize(pipeline, args):

 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
-    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline = DiffusionPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
     pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline
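
Swapping FluxPipeline for DiffusionPipeline works because DiffusionPipeline.from_pretrained reads the checkpoint's model_index.json and instantiates the concrete pipeline class recorded there, so the same loader now covers both the text-to-image and the Kontext checkpoints without branching on args.ckpt. A quick way to confirm the dispatch, assuming both checkpoints are available locally or via the Hub:

# Prints the concrete class that DiffusionPipeline resolves from each
# checkpoint's model_index.json (requires access to both checkpoints).
import torch
from diffusers import DiffusionPipeline

for ckpt in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-Kontext-dev"):
    pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    print(ckpt, "->", type(pipe).__name__)  # expected: FluxPipeline, FluxKontextPipeline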

0 commit comments
