@@ -54,14 +54,15 @@ def replace_new_forward(unet):
5454 upsample_block .forward = types .MethodType (cacheupblock2d_forward , upsample_block )
5555
5656
57- def get_input_info (dummy_dict , info = None ):
57+ def get_input_info (dummy_dict , info : str = None , batch_size : int = 1 ):
5858 return_val = [] if info == "profile_shapes" or info == "input_names" else {}
5959
6060 def collect_leaf_keys (d ):
6161 for key , value in d .items ():
6262 if isinstance (value , dict ):
6363 collect_leaf_keys (value )
6464 else :
65+ value = (value [0 ] * batch_size ,) + value [1 :]
6566 if info == "profile_shapes" :
6667 return_val .append ((key , value )) # type: ignore
6768 elif info == "profile_shapes_dict" :
@@ -75,7 +76,7 @@ def collect_leaf_keys(d):
7576 return return_val
7677
7778
78- def complie2trt (onnx_path : Path , engine_path : Path ):
79+ def complie2trt (onnx_path : Path , engine_path : Path , batch_size : int = 1 ):
7980 subdirs = [f for f in onnx_path .iterdir () if f .is_dir ()]
8081 for subdir in subdirs :
8182 if subdir .name not in SDXL_ONNX_CONFIG .keys ():
@@ -86,15 +87,17 @@ def complie2trt(onnx_path: Path, engine_path: Path):
8687 print (f"Building { str (model_path )} " )
8788 build_profile = Profile ()
8889 profile_shapes = get_input_info (
89- SDXL_ONNX_CONFIG [subdir .name ]["dummy_input" ], "profile_shapes"
90+ SDXL_ONNX_CONFIG [subdir .name ]["dummy_input" ], "profile_shapes" , batch_size
9091 )
9192 for input_name , input_shape in profile_shapes :
92- build_profile .add (input_name , input_shape , input_shape , input_shape )
93+ min_input_shape = (2 ,) + input_shape [1 :]
94+ build_profile .add (input_name , min_input_shape , input_shape , input_shape )
9395 block_network = network_from_onnx_path (
94- str (model_path ), flags = [trt .OnnxParserFlag .NATIVE_INSTANCENORM ]
96+ str (model_path ), flags = [trt .OnnxParserFlag .NATIVE_INSTANCENORM ], strongly_typed = True
9597 )
9698 build_config = CreateConfig (
97- fp16 = True ,
99+ builder_optimization_level = 4 ,
100+ tf32 = True ,
98101 profiles = [build_profile ],
99102 )
100103 engine = engine_from_network (
@@ -113,7 +116,7 @@ def get_total_device_memory(unet):
113116 return max_device_memory
114117
115118
116- def load_engines (unet , engine_path : Path ):
119+ def load_engines (unet , engine_path : Path , batch_size : int = 1 ):
117120 unet .engines = {}
118121 for f in engine_path .iterdir ():
119122 if f .is_file ():
@@ -127,9 +130,10 @@ def load_engines(unet, engine_path: Path):
127130 for block_name in unet .engines .keys ():
128131 unet .engines [block_name ].allocate_buffers (
129132 shape_dict = get_input_info (
130- SDXL_ONNX_CONFIG [block_name ]["dummy_input" ], "profile_shapes_dict"
133+ SDXL_ONNX_CONFIG [block_name ]["dummy_input" ], "profile_shapes_dict" , batch_size
131134 ),
132135 device = unet .device ,
136+ batch_size = batch_size ,
133137 )
134138 # TODO: Free and clean up the origin pytorch cuda memory
135139
@@ -216,10 +220,12 @@ def export_onnx(unet, onnx_path: Path):
216220 print (f"{ str (_onnx_file )} alread exists!" )
217221
218222
def warm_up(unet, batch_size: int = 1):
    """Run one dummy inference through every loaded TRT engine.

    Triggers lazy initialization (context/buffer setup) so the first real
    denoising step does not pay the cold-start cost.

    Args:
        unet: UNet holding ``engines`` (dict of TRT engine wrappers) and
            ``cuda_stream``; engines must already be loaded/allocated.
        batch_size: Batch size the dummy inputs are scaled to; must match
            the size the engines were built and allocated for.
    """
    print("Warming-up TensorRT engines...")
    for block_name, engine in unet.engines.items():
        # Build batch-scaled dummy tensors from the per-block ONNX config.
        dummy_input = get_input_info(
            SDXL_ONNX_CONFIG[block_name]["dummy_input"], "dummy_input", batch_size
        )
        engine(dummy_input, unet.cuda_stream)
224230
225231
@@ -231,13 +237,13 @@ def teardown(unet):
231237 del unet .cuda_stream
232238
233239
def compile(unet, onnx_path: Path, engine_path: Path, batch_size: int = 1):
    """End-to-end TensorRT compilation pipeline for the UNet.

    Patches the UNet forwards, exports ONNX graphs, builds TRT engines,
    loads them onto the device, and warms them up. On success the UNet is
    flagged to use TRT inference.

    Args:
        unet: The diffusion UNet to compile.
        onnx_path: Directory for exported ONNX models (created if missing).
        engine_path: Directory for built TRT engines (created if missing).
        batch_size: Batch size the engine profiles/buffers are sized for.
    """
    for directory in (onnx_path, engine_path):
        directory.mkdir(parents=True, exist_ok=True)

    replace_new_forward(unet)
    export_onnx(unet, onnx_path)
    # NOTE(review): "complie2trt" is a pre-existing typo in the sibling
    # function's name; kept as-is to match its definition.
    complie2trt(onnx_path, engine_path, batch_size)
    load_engines(unet, engine_path, batch_size)
    warm_up(unet, batch_size)
    unet.use_trt_infer = True
0 commit comments