21
21
import numpy as np
22
22
from dygraph_to_static_utils import (
23
23
Dy2StTestBase ,
24
+ enable_to_static_guard ,
25
+ static_guard ,
24
26
test_default_and_pir ,
25
27
)
26
28
from predictor_utils import PredictorTools
@@ -247,32 +249,31 @@ def __len__(self):
247
249
return len (self .img )
248
250
249
251
250
- class TestResnet ( Dy2StTestBase ) :
251
- def setUp (self ):
252
+ class ResNetHelper :
253
+ def __init__ (self ):
252
254
self .temp_dir = tempfile .TemporaryDirectory ()
253
255
254
256
self .model_save_dir = os .path .join (self .temp_dir .name , "./inference" )
255
257
self .model_save_prefix = os .path .join (
256
- self .temp_dir .name , "./inference/resnet_v2 "
258
+ self .temp_dir .name , "./inference/resnet "
257
259
)
258
260
self .model_filename = (
259
- "resnet_v2 " + paddle .jit .translated_layer .INFER_MODEL_SUFFIX
261
+ "resnet " + paddle .jit .translated_layer .INFER_MODEL_SUFFIX
260
262
)
261
263
self .params_filename = (
262
- "resnet_v2 " + paddle .jit .translated_layer .INFER_PARAMS_SUFFIX
264
+ "resnet " + paddle .jit .translated_layer .INFER_PARAMS_SUFFIX
263
265
)
264
266
self .dy_state_dict_save_path = os .path .join (
265
- self .temp_dir .name , "./resnet_v2 .dygraph"
267
+ self .temp_dir .name , "./resnet .dygraph"
266
268
)
267
269
268
- def tearDown (self ):
270
+ def __del__ (self ):
269
271
self .temp_dir .cleanup ()
270
272
271
- def do_train (self , to_static ):
273
+ def train (self , to_static , build_strategy = None ):
272
274
"""
273
275
Tests model decorated by `dygraph_to_static_output` in static graph mode. For users, the model is defined in dygraph mode and trained in static graph mode.
274
276
"""
275
- paddle .disable_static (place )
276
277
np .random .seed (SEED )
277
278
paddle .seed (SEED )
278
279
paddle .framework .random ._manual_program_seed (SEED )
@@ -284,7 +285,7 @@ def do_train(self, to_static):
284
285
dataset , batch_size = batch_size , drop_last = True
285
286
)
286
287
287
- resnet = paddle .jit .to_static (ResNet ())
288
+ resnet = paddle .jit .to_static (ResNet (), build_strategy = build_strategy )
288
289
optimizer = optimizer_setting (parameter_list = resnet .parameters ())
289
290
290
291
for epoch in range (epoch_num ):
@@ -350,59 +351,55 @@ def do_train(self, to_static):
350
351
self .dy_state_dict_save_path + '.pdparams' ,
351
352
)
352
353
break
353
- paddle .enable_static ()
354
354
355
355
return total_loss .numpy ()
356
356
357
357
def predict_dygraph (self , data ):
358
- paddle .jit .enable_to_static (False )
359
- paddle .disable_static (place )
360
- resnet = paddle .jit .to_static (ResNet ())
358
+ with enable_to_static_guard (False ):
359
+ resnet = paddle .jit .to_static (ResNet ())
361
360
362
- model_dict = paddle .load (self .dy_state_dict_save_path + '.pdparams' )
363
- resnet .set_dict (model_dict )
364
- resnet .eval ()
361
+ model_dict = paddle .load (self .dy_state_dict_save_path + '.pdparams' )
362
+ resnet .set_dict (model_dict )
363
+ resnet .eval ()
365
364
366
- pred_res = resnet (
367
- paddle .to_tensor (
368
- data = data , dtype = None , place = None , stop_gradient = True
365
+ pred_res = resnet (
366
+ paddle .to_tensor (
367
+ data = data , dtype = None , place = None , stop_gradient = True
368
+ )
369
369
)
370
- )
371
370
372
371
ret = pred_res .numpy ()
373
- paddle .enable_static ()
374
372
return ret
375
373
376
374
def predict_static (self , data ):
377
- exe = paddle .static .Executor (place )
378
- [
379
- inference_program ,
380
- feed_target_names ,
381
- fetch_targets ,
382
- ] = paddle .static .load_inference_model (
383
- self .model_save_dir ,
384
- executor = exe ,
385
- model_filename = self .model_filename ,
386
- params_filename = self .params_filename ,
387
- )
375
+ with static_guard ():
376
+ exe = paddle .static .Executor (place )
377
+ [
378
+ inference_program ,
379
+ feed_target_names ,
380
+ fetch_targets ,
381
+ ] = paddle .static .load_inference_model (
382
+ self .model_save_dir ,
383
+ executor = exe ,
384
+ model_filename = self .model_filename ,
385
+ params_filename = self .params_filename ,
386
+ )
388
387
389
- pred_res = exe .run (
390
- inference_program ,
391
- feed = {feed_target_names [0 ]: data },
392
- fetch_list = fetch_targets ,
393
- )
388
+ pred_res = exe .run (
389
+ inference_program ,
390
+ feed = {feed_target_names [0 ]: data },
391
+ fetch_list = fetch_targets ,
392
+ )
394
393
395
- return pred_res [0 ]
394
+ return pred_res [0 ]
396
395
397
396
def predict_dygraph_jit (self , data ):
398
- paddle .disable_static (place )
399
397
resnet = paddle .jit .load (self .model_save_prefix )
400
398
resnet .eval ()
401
399
402
400
pred_res = resnet (data )
403
401
404
402
ret = pred_res .numpy ()
405
- paddle .enable_static ()
406
403
return ret
407
404
408
405
def predict_analysis_inference (self , data ):
@@ -415,16 +412,21 @@ def predict_analysis_inference(self, data):
415
412
(out ,) = output ()
416
413
return out
417
414
415
+
416
+ class TestResnet (Dy2StTestBase ):
417
+ def setUp (self ):
418
+ self .resnet_helper = ResNetHelper ()
419
+
418
420
def train (self , to_static ):
419
- paddle . jit . enable_to_static (to_static )
420
- return self .do_train (to_static )
421
+ with enable_to_static_guard (to_static ):
422
+ return self .resnet_helper . train (to_static )
421
423
422
424
def verify_predict (self ):
423
425
image = np .random .random ([1 , 3 , 224 , 224 ]).astype ('float32' )
424
- dy_pre = self .predict_dygraph (image )
425
- st_pre = self .predict_static (image )
426
- dy_jit_pre = self .predict_dygraph_jit (image )
427
- predictor_pre = self .predict_analysis_inference (image )
426
+ dy_pre = self .resnet_helper . predict_dygraph (image )
427
+ st_pre = self .resnet_helper . predict_static (image )
428
+ dy_jit_pre = self .resnet_helper . predict_dygraph_jit (image )
429
+ predictor_pre = self .resnet_helper . predict_analysis_inference (image )
428
430
np .testing .assert_allclose (
429
431
dy_pre ,
430
432
st_pre ,
@@ -455,7 +457,7 @@ def test_resnet(self):
455
457
err_msg = f'static_loss: { static_loss } \n dygraph_loss: { dygraph_loss } ' ,
456
458
)
457
459
# TODO(@xiongkun): open after save / load supported in pir.
458
- if not paddle .base . framework .use_pir_api ():
460
+ if not paddle .framework .use_pir_api ():
459
461
self .verify_predict ()
460
462
461
463
@test_default_and_pir
@@ -474,12 +476,12 @@ def test_resnet_composite(self):
474
476
475
477
@test_default_and_pir
476
478
def test_in_static_mode_mkldnn (self ):
477
- paddle .base . set_flags ({'FLAGS_use_mkldnn' : True })
479
+ paddle .set_flags ({'FLAGS_use_mkldnn' : True })
478
480
try :
479
481
if paddle .base .core .is_compiled_with_mkldnn ():
480
482
self .train (to_static = True )
481
483
finally :
482
- paddle .base . set_flags ({'FLAGS_use_mkldnn' : False })
484
+ paddle .set_flags ({'FLAGS_use_mkldnn' : False })
483
485
484
486
485
487
if __name__ == '__main__' :
0 commit comments