Commit 5a49b30

More configuration + image gen script using cached model
1 parent b6bbbcb commit 5a49b30

File tree: 4 files changed (+45, -6 lines)

gen_image.py

29 additions, 0 deletions (new file)

@@ -0,0 +1,29 @@
+import random
+import time
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
+from utils.benchmark_utils import annotate, create_parser
+from utils.pipeline_utils import load_pipeline  # noqa: E402
+
+
+def set_rand_seeds(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def main(args):
+    pipeline = load_pipeline(args)
+    set_rand_seeds(args.seed)
+
+    image = pipeline(
+        args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+    ).images[0]
+    image.save(args.output_file)
+
+
+if __name__ == "__main__":
+    parser = create_parser()
+    args = parser.parse_args()
+    # use the cached model to minimize latency
+    args.use_cached_model = True
+    main(args)
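Usage note (not part of the diff): because gen_image.py forces args.use_cached_model = True, it expects an already-exported artifact in --cache-dir, e.g. from a prior run_benchmark.py run, which serializes by default. A hypothetical invocation, assuming --prompt is the flag behind args.prompt and using the prompt text only as an illustration:

    python gen_image.py \
        --prompt "a photo of a corgi riding a skateboard" \
        --num_inference_steps 4 \
        --seed 42 \
        --output-file corgi.png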

run_benchmark.py

7 additions, 4 deletions

@@ -6,8 +6,15 @@
 from utils.pipeline_utils import load_pipeline  # noqa: E402


+def set_rand_seeds(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
 def main(args):
+    set_rand_seeds(args.seed)
     pipeline = load_pipeline(args)
+    set_rand_seeds(args.seed)

     # warmup
     for _ in range(3):
@@ -52,10 +59,6 @@ def main(args):


 if __name__ == "__main__":
-    seed = 42
-    random.seed(seed)
-    torch.manual_seed(seed)
-
     parser = create_parser()
     args = parser.parse_args()
     main(args)
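The refactor hoists the previously hard-coded seed into a set_rand_seeds helper driven by the new --seed flag; note it is called both before and after load_pipeline, presumably so any RNG state consumed during loading does not perturb generation. A minimal sanity check (not part of the commit) that the helper yields reproducible draws:

    import random
    import torch

    def set_rand_seeds(seed):
        random.seed(seed)
        torch.manual_seed(seed)

    set_rand_seeds(42)
    a = torch.randn(3)
    set_rand_seeds(42)
    b = torch.randn(3)
    assert torch.equal(a, b)  # identical seed -> identical tensor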

utils/benchmark_utils.py

4 additions, 0 deletions

@@ -15,12 +15,16 @@ def create_parser():
                         help="Text prompt")
     parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
                         help="Cache directory for storing exported models")
+    parser.add_argument("--use-cached-model", action="store_true",
+                        help="Attempt to use cached model only (don't re-export)")
     parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda",
                         help="Device to use")
     parser.add_argument("--num_inference_steps", type=int, default=4,
                         help="Number of denoising steps")
     parser.add_argument("--output-file", type=str, default="output.png",
                         help="Output image file path")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Random seed to use")
     # file path for optional output PyTorch Profiler trace
     parser.add_argument("--trace-file", type=str, default=None,
                         help="Output PyTorch Profiler trace file path")

utils/pipeline_utils.py

5 additions, 2 deletions

@@ -375,8 +375,11 @@ def optimize(pipeline, args):
     if args.compile_export_mode == "compile":
         pipeline = use_compile(pipeline)
     elif args.compile_export_mode == "export_aoti":
-        # NB: Using a cached export + AOTI model is not supported yet
-        pipeline = use_export_aoti(pipeline, cache_dir=args.cache_dir, serialize=True)
+        pipeline = use_export_aoti(
+            pipeline,
+            cache_dir=args.cache_dir,
+            serialize=(not args.use_cached_model),
+        )
     elif args.compile_export_mode == "disabled":
         pass
     else:
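Dropping the "not supported yet" note together with serialize=(not args.use_cached_model) implies use_export_aoti can now skip re-export when a serialized artifact is already on disk. A hypothetical sketch of that load-or-export pattern (the real use_export_aoti body is not shown in this diff; function and path names below are illustrative, using the torch._export AOTInductor helpers):

    import os
    import torch

    def load_or_export(module, example_inputs, so_path, device="cuda", serialize=True):
        # Cached-model path: if we are not re-serializing and an artifact
        # already exists, load the AOT-compiled shared library directly.
        if not serialize and os.path.exists(so_path):
            return torch._export.aot_load(so_path, device)
        # Export path: AOT-compile the module, writing the artifact to so_path.
        torch._export.aot_compile(
            module, example_inputs,
            options={"aot_inductor.output_path": so_path},
        )
        return torch._export.aot_load(so_path, device)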
