@@ -56,11 +56,21 @@ def parse_arguments():
56
56
type = str ,
57
57
default = "paddle" ,
58
58
# Note(zhoushunjie): Will support 'tensorrt', 'paddle-tensorrt' soon.
59
+ choices = ["onnx_runtime" , "paddle" , "paddlelite" ],
60
+ help = "The inference runtime backend of unet model and text encoder model." ,
61
+ )
62
+ parser .add_argument (
63
+ "--device" ,
64
+ type = str ,
65
+ default = "gpu" ,
66
+ # Note(shentanyue): Will support more devices.
59
67
choices = [
60
- "onnx_runtime" ,
61
- "paddle" ,
68
+ "cpu" ,
69
+ "gpu" ,
70
+ "huawei_ascend_npu" ,
71
+ "kunlunxin_xpu" ,
62
72
],
63
- help = "The inference runtime backend of unet model and text encoder model ." ,
73
+ help = "The inference runtime device of models ." ,
64
74
)
65
75
parser .add_argument (
66
76
"--image_path" , default = "fd_astronaut_rides_horse.png" , help = "The model directory of diffusion_model."
@@ -123,6 +133,25 @@ def create_paddle_inference_runtime(
123
133
return fd .Runtime (option )
124
134
125
135
136
def create_paddle_lite_runtime(model_dir, model_prefix, device="cpu", device_id=0):
    """Build a FastDeploy runtime that uses the Paddle Lite backend.

    Args:
        model_dir: Root directory that contains one sub-directory per model.
        model_prefix: Name of the sub-directory (under ``model_dir``) holding
            ``inference.pdmodel`` and ``inference.pdiparams``.
        device: Target device name. Only ``"huawei_ascend_npu"`` triggers
            extra configuration here; any other value leaves the option
            untouched beyond selecting the Lite backend.
        device_id: Device index, forwarded to the NNAdapter context
            properties when ``device`` is ``"huawei_ascend_npu"``.

    Returns:
        The object returned by ``fd.Runtime(option)`` for the configured option.
    """
    option = fd.RuntimeOption()
    option.use_lite_backend()

    # Both the model files and the NNAdapter cache live in this directory.
    submodel_dir = os.path.join(model_dir, model_prefix)

    if device == "huawei_ascend_npu":
        option.use_cann()
        option.set_lite_nnadapter_device_names(["huawei_ascend_npu"])
        option.set_lite_nnadapter_model_cache_dir(submodel_dir)
        option.set_lite_nnadapter_context_properties(
            f"HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS={device_id}"
        )
    # TODO(shentanyue): Add kunlunxin_xpu code. Until then "kunlunxin_xpu"
    # and every other device value get no device-specific configuration.

    option.set_model_path(
        os.path.join(submodel_dir, "inference.pdmodel"),
        os.path.join(submodel_dir, "inference.pdiparams"),
    )
    return fd.Runtime(option)
153
+
154
+
126
155
def create_trt_runtime (model_dir , model_prefix , model_format , workspace = (1 << 31 ), dynamic_shape = None , device_id = 0 ):
127
156
option = fd .RuntimeOption ()
128
157
option .use_trt_backend ()
@@ -210,42 +239,45 @@ def get_scheduler(args):
210
239
}
211
240
212
241
# 4. Init runtime
242
+ device_id = args .device_id
243
+ if args .device == "cpu" :
244
+ device_id = - 1
213
245
if args .backend == "onnx_runtime" :
214
246
text_encoder_runtime = create_ort_runtime (
215
- args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = args . device_id
247
+ args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = device_id
216
248
)
217
249
vae_decoder_runtime = create_ort_runtime (
218
- args .model_dir , args .vae_decoder_model_prefix , args .model_format , device_id = args . device_id
250
+ args .model_dir , args .vae_decoder_model_prefix , args .model_format , device_id = device_id
219
251
)
220
252
vae_encoder_runtime = create_ort_runtime (
221
- args .model_dir , args .vae_encoder_model_prefix , args .model_format , device_id = args . device_id
253
+ args .model_dir , args .vae_encoder_model_prefix , args .model_format , device_id = device_id
222
254
)
223
255
start = time .time ()
224
256
unet_runtime = create_ort_runtime (
225
- args .model_dir , args .unet_model_prefix , args .model_format , device_id = args . device_id
257
+ args .model_dir , args .unet_model_prefix , args .model_format , device_id = device_id
226
258
)
227
259
print (f"Spend { time .time () - start : .2f} s to load unet model." )
228
260
elif args .backend == "paddle" or args .backend == "paddle-tensorrt" :
229
261
use_trt = True if args .backend == "paddle-tensorrt" else False
230
262
# Note(zhoushunjie): Will change to paddle runtime later
231
263
text_encoder_runtime = create_ort_runtime (
232
- args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = args . device_id
264
+ args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = device_id
233
265
)
234
266
vae_decoder_runtime = create_paddle_inference_runtime (
235
267
args .model_dir ,
236
268
args .vae_decoder_model_prefix ,
237
269
use_trt ,
238
270
vae_decoder_dynamic_shape ,
239
271
use_fp16 = args .use_fp16 ,
240
- device_id = args . device_id ,
272
+ device_id = device_id ,
241
273
)
242
274
vae_encoder_runtime = create_paddle_inference_runtime (
243
275
args .model_dir ,
244
276
args .vae_encoder_model_prefix ,
245
277
use_trt ,
246
278
vae_encoder_dynamic_shape ,
247
279
use_fp16 = args .use_fp16 ,
248
- device_id = args . device_id ,
280
+ device_id = device_id ,
249
281
)
250
282
start = time .time ()
251
283
unet_runtime = create_paddle_inference_runtime (
@@ -254,7 +286,7 @@ def get_scheduler(args):
254
286
use_trt ,
255
287
unet_dynamic_shape ,
256
288
use_fp16 = args .use_fp16 ,
257
- device_id = args . device_id ,
289
+ device_id = device_id ,
258
290
)
259
291
print (f"Spend { time .time () - start : .2f} s to load unet model." )
260
292
elif args .backend == "tensorrt" :
@@ -265,23 +297,38 @@ def get_scheduler(args):
265
297
args .model_format ,
266
298
workspace = (1 << 30 ),
267
299
dynamic_shape = vae_decoder_dynamic_shape ,
268
- device_id = args . device_id ,
300
+ device_id = device_id ,
269
301
)
270
302
vae_encoder_runtime = create_trt_runtime (
271
303
args .model_dir ,
272
304
args .vae_encoder_model_prefix ,
273
305
args .model_format ,
274
306
workspace = (1 << 30 ),
275
307
dynamic_shape = vae_encoder_dynamic_shape ,
276
- device_id = args . device_id ,
308
+ device_id = device_id ,
277
309
)
278
310
start = time .time ()
279
311
unet_runtime = create_trt_runtime (
280
312
args .model_dir ,
281
313
args .unet_model_prefix ,
282
314
args .model_format ,
283
315
dynamic_shape = unet_dynamic_shape ,
284
- device_id = args .device_id ,
316
+ device_id = device_id ,
317
+ )
318
+ print (f"Spend { time .time () - start : .2f} s to load unet model." )
319
+ elif args .backend == "paddlelite" :
320
+ text_encoder_runtime = create_paddle_lite_runtime (
321
+ args .model_dir , args .text_encoder_model_prefix , device = args .device , device_id = device_id
322
+ )
323
+ vae_decoder_runtime = create_paddle_lite_runtime (
324
+ args .model_dir , args .vae_decoder_model_prefix , device = args .device , device_id = device_id
325
+ )
326
+ vae_encoder_runtime = create_paddle_lite_runtime (
327
+ args .model_dir , args .vae_encoder_model_prefix , device = args .device , device_id = device_id
328
+ )
329
+ start = time .time ()
330
+ unet_runtime = create_paddle_lite_runtime (
331
+ args .model_dir , args .unet_model_prefix , device = args .device , device_id = device_id
285
332
)
286
333
print (f"Spend { time .time () - start : .2f} s to load unet model." )
287
334
0 commit comments