44from torch .profiler import profile , record_function , ProfilerActivity
55from utils .benchmark_utils import annotate , create_parser
66from utils .pipeline_utils import load_pipeline # noqa: E402
7+ from diffusers .utils import load_image
8+ import os
79
810def _determine_pipe_call_kwargs (args ):
911 kwargs = {"max_sequence_length" : 256 , "guidance_scale" : 0.0 }
@@ -12,6 +14,7 @@ def _determine_pipe_call_kwargs(args):
1214 kwargs = {"max_sequence_length" : 512 , "guidance_scale" : 3.5 }
1315 elif ckpt_id == "black-forest-labs/FLUX.1-Kontext-dev" :
1416 kwargs = {"max_sequence_length" : 512 , "guidance_scale" : 2.5 }
17+ kwargs .update ({"image" : load_image (args .image )})
1518 return kwargs
1619
1720def set_rand_seeds (seed ):
@@ -22,12 +25,14 @@ def set_rand_seeds(seed):
2225def main (args ):
2326 set_rand_seeds (args .seed )
2427 pipeline = load_pipeline (args )
28+ if args .ckpt == "black-forest-labs/FLUX.1-Kontext-dev" :
29+ assert os .path .exists (args .image )
2530 set_rand_seeds (args .seed )
2631
2732 # warmup
2833 for _ in range (3 ):
2934 image = pipeline (
30- args .prompt ,
35+ prompt = args .prompt ,
3136 num_inference_steps = args .num_inference_steps ,
3237 generator = torch .manual_seed (args .seed ),
3338 ** _determine_pipe_call_kwargs (args )
@@ -38,7 +43,7 @@ def main(args):
3843 for _ in range (10 ):
3944 begin = time .time ()
4045 image = pipeline (
41- args .prompt ,
46+ prompt = args .prompt ,
4247 num_inference_steps = args .num_inference_steps ,
4348 generator = torch .manual_seed (args .seed ),
4449 ** _determine_pipe_call_kwargs (args )
0 commit comments