Skip to content

Commit d47d31e

Browse files
authored
[2.6][Dy2St] Decrease test_train_step time and disable some uts on CPU(#59867) (#59923)
* [Dy2St] Decrease `test_train_step` time (#59867) * [Dy2St] Decrease `test_train_step` time * rename back * [Dy2St] Disable `test_resnet` and `test_build_strategy` on CPU tests (#59742) * [Dy2St] Decrease `test_resnet` time * temp rename to resnetx * increase timeout of test_resnet * remove pir test case * dec timeout * inc timeout * disable on CPU * rename to merge * rename to resnetx * rename to resnet * fix import * fix v2 * fix test_build_strategy * rename to test_build_strategyx * rename back * only,cherry-pick; test=document_fix
1 parent 217cc54 commit d47d31e

File tree

6 files changed

+74
-76
lines changed

6 files changed

+74
-76
lines changed

test/dygraph_to_static/CMakeLists.txt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ if(NOT WITH_GPU)
2525
# We should remove this after fix the performance issue.
2626
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_adam)
2727
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_sgd)
28+
# disable some model test on CPU to avoid timeout
29+
list(REMOVE_ITEM TEST_OPS test_resnet)
30+
list(REMOVE_ITEM TEST_OPS test_build_strategy)
2831
endif()
2932

3033
foreach(TEST_OP ${TEST_OPS})
@@ -43,19 +46,14 @@ set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
4346
set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
4447
set_tests_properties(test_bmn PROPERTIES TIMEOUT 300)
4548
set_tests_properties(test_bert PROPERTIES TIMEOUT 240)
46-
#set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
47-
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
4849

4950
if(NOT WIN32)
5051
set_tests_properties(test_tsm PROPERTIES TIMEOUT 900)
51-
set_tests_properties(test_resnet PROPERTIES TIMEOUT 300)
5252
endif()
5353

5454
if(APPLE)
5555
set_tests_properties(test_bmn PROPERTIES TIMEOUT 300)
56-
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 300)
5756
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 300)
58-
set_tests_properties(test_resnet PROPERTIES TIMEOUT 300)
5957
endif()
6058

6159
if(WITH_GPU)

test/dygraph_to_static/dygraph_to_static_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,23 +381,23 @@ def test_legacy_and_pt_and_pir(fn):
381381
return fn
382382

383383

384+
# Some decorators for save CI time
384385
def test_default_mode_only(fn):
385386
# Some unittests has high time complexity, we only test them with default mode
386387
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
387388
fn = set_ir_mode(IrMode.PT)(fn)
388389
return fn
389390

390391

391-
def test_sot_with_pir_only(fn):
392+
def test_default_and_pir(fn):
393+
# Some unittests has high time complexity, we only test them with default mode
392394
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
393-
fn = set_ir_mode(IrMode.PIR)(fn)
395+
fn = set_ir_mode(IrMode.PT | IrMode.PIR)(fn)
394396
return fn
395397

396398

397-
def test_default_and_pir(fn):
398-
# Some unittests has high time complexity, we only test them with default mode
399+
def test_sot_mgs0_only(fn):
399400
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
400-
fn = set_ir_mode(IrMode.PT | IrMode.PIR)(fn)
401401
return fn
402402

403403

test/dygraph_to_static/test_build_strategy.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
from dygraph_to_static_utils import (
1919
Dy2StTestBase,
2020
enable_to_static_guard,
21-
test_ast_only,
22-
test_default_and_pir,
23-
test_pt_only,
21+
test_default_mode_only,
22+
test_legacy_and_pt_and_pir,
2423
)
2524
from test_resnet import ResNetHelper
2625

@@ -36,7 +35,7 @@ def setUp(self):
3635
self.build_strategy.enable_addto = True
3736
self.resnet_helper = ResNetHelper()
3837
# NOTE: for enable_addto
39-
paddle.base.set_flags({"FLAGS_max_inplace_grad_add": 8})
38+
paddle.set_flags({"FLAGS_max_inplace_grad_add": 8})
4039

4140
def train(self, to_static):
4241
with enable_to_static_guard(to_static):
@@ -67,8 +66,7 @@ def verify_predict(self):
6766
err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
6867
)
6968

70-
@test_ast_only
71-
@test_pt_only
69+
@test_default_mode_only
7270
def test_resnet(self):
7371
static_loss = self.train(to_static=True)
7472
dygraph_loss = self.train(to_static=False)
@@ -80,19 +78,18 @@ def test_resnet(self):
8078
)
8179
self.verify_predict()
8280

83-
@test_ast_only
84-
@test_pt_only
81+
@test_default_mode_only
8582
def test_in_static_mode_mkldnn(self):
86-
paddle.base.set_flags({'FLAGS_use_mkldnn': True})
83+
paddle.set_flags({'FLAGS_use_mkldnn': True})
8784
try:
8885
if paddle.base.core.is_compiled_with_mkldnn():
8986
self.resnet_helper.train(True, self.build_strategy)
9087
finally:
91-
paddle.base.set_flags({'FLAGS_use_mkldnn': False})
88+
paddle.set_flags({'FLAGS_use_mkldnn': False})
9289

9390

9491
class TestError(Dy2StTestBase):
95-
@test_default_and_pir
92+
@test_legacy_and_pt_and_pir
9693
def test_type_error(self):
9794
def foo(x):
9895
out = x + 1

test/dygraph_to_static/test_resnet.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import numpy as np
2222
from dygraph_to_static_utils import (
2323
Dy2StTestBase,
24+
enable_to_static_guard,
25+
static_guard,
2426
test_default_and_pir,
2527
)
2628
from predictor_utils import PredictorTools
@@ -247,32 +249,31 @@ def __len__(self):
247249
return len(self.img)
248250

249251

250-
class TestResnet(Dy2StTestBase):
251-
def setUp(self):
252+
class ResNetHelper:
253+
def __init__(self):
252254
self.temp_dir = tempfile.TemporaryDirectory()
253255

254256
self.model_save_dir = os.path.join(self.temp_dir.name, "./inference")
255257
self.model_save_prefix = os.path.join(
256-
self.temp_dir.name, "./inference/resnet_v2"
258+
self.temp_dir.name, "./inference/resnet"
257259
)
258260
self.model_filename = (
259-
"resnet_v2" + paddle.jit.translated_layer.INFER_MODEL_SUFFIX
261+
"resnet" + paddle.jit.translated_layer.INFER_MODEL_SUFFIX
260262
)
261263
self.params_filename = (
262-
"resnet_v2" + paddle.jit.translated_layer.INFER_PARAMS_SUFFIX
264+
"resnet" + paddle.jit.translated_layer.INFER_PARAMS_SUFFIX
263265
)
264266
self.dy_state_dict_save_path = os.path.join(
265-
self.temp_dir.name, "./resnet_v2.dygraph"
267+
self.temp_dir.name, "./resnet.dygraph"
266268
)
267269

268-
def tearDown(self):
270+
def __del__(self):
269271
self.temp_dir.cleanup()
270272

271-
def do_train(self, to_static):
273+
def train(self, to_static, build_strategy=None):
272274
"""
273275
Tests model decorated by `dygraph_to_static_output` in static graph mode. For users, the model is defined in dygraph mode and trained in static graph mode.
274276
"""
275-
paddle.disable_static(place)
276277
np.random.seed(SEED)
277278
paddle.seed(SEED)
278279
paddle.framework.random._manual_program_seed(SEED)
@@ -284,7 +285,7 @@ def do_train(self, to_static):
284285
dataset, batch_size=batch_size, drop_last=True
285286
)
286287

287-
resnet = paddle.jit.to_static(ResNet())
288+
resnet = paddle.jit.to_static(ResNet(), build_strategy=build_strategy)
288289
optimizer = optimizer_setting(parameter_list=resnet.parameters())
289290

290291
for epoch in range(epoch_num):
@@ -350,59 +351,55 @@ def do_train(self, to_static):
350351
self.dy_state_dict_save_path + '.pdparams',
351352
)
352353
break
353-
paddle.enable_static()
354354

355355
return total_loss.numpy()
356356

357357
def predict_dygraph(self, data):
358-
paddle.jit.enable_to_static(False)
359-
paddle.disable_static(place)
360-
resnet = paddle.jit.to_static(ResNet())
358+
with enable_to_static_guard(False):
359+
resnet = paddle.jit.to_static(ResNet())
361360

362-
model_dict = paddle.load(self.dy_state_dict_save_path + '.pdparams')
363-
resnet.set_dict(model_dict)
364-
resnet.eval()
361+
model_dict = paddle.load(self.dy_state_dict_save_path + '.pdparams')
362+
resnet.set_dict(model_dict)
363+
resnet.eval()
365364

366-
pred_res = resnet(
367-
paddle.to_tensor(
368-
data=data, dtype=None, place=None, stop_gradient=True
365+
pred_res = resnet(
366+
paddle.to_tensor(
367+
data=data, dtype=None, place=None, stop_gradient=True
368+
)
369369
)
370-
)
371370

372371
ret = pred_res.numpy()
373-
paddle.enable_static()
374372
return ret
375373

376374
def predict_static(self, data):
377-
exe = paddle.static.Executor(place)
378-
[
379-
inference_program,
380-
feed_target_names,
381-
fetch_targets,
382-
] = paddle.static.load_inference_model(
383-
self.model_save_dir,
384-
executor=exe,
385-
model_filename=self.model_filename,
386-
params_filename=self.params_filename,
387-
)
375+
with static_guard():
376+
exe = paddle.static.Executor(place)
377+
[
378+
inference_program,
379+
feed_target_names,
380+
fetch_targets,
381+
] = paddle.static.load_inference_model(
382+
self.model_save_dir,
383+
executor=exe,
384+
model_filename=self.model_filename,
385+
params_filename=self.params_filename,
386+
)
388387

389-
pred_res = exe.run(
390-
inference_program,
391-
feed={feed_target_names[0]: data},
392-
fetch_list=fetch_targets,
393-
)
388+
pred_res = exe.run(
389+
inference_program,
390+
feed={feed_target_names[0]: data},
391+
fetch_list=fetch_targets,
392+
)
394393

395-
return pred_res[0]
394+
return pred_res[0]
396395

397396
def predict_dygraph_jit(self, data):
398-
paddle.disable_static(place)
399397
resnet = paddle.jit.load(self.model_save_prefix)
400398
resnet.eval()
401399

402400
pred_res = resnet(data)
403401

404402
ret = pred_res.numpy()
405-
paddle.enable_static()
406403
return ret
407404

408405
def predict_analysis_inference(self, data):
@@ -415,16 +412,21 @@ def predict_analysis_inference(self, data):
415412
(out,) = output()
416413
return out
417414

415+
416+
class TestResnet(Dy2StTestBase):
417+
def setUp(self):
418+
self.resnet_helper = ResNetHelper()
419+
418420
def train(self, to_static):
419-
paddle.jit.enable_to_static(to_static)
420-
return self.do_train(to_static)
421+
with enable_to_static_guard(to_static):
422+
return self.resnet_helper.train(to_static)
421423

422424
def verify_predict(self):
423425
image = np.random.random([1, 3, 224, 224]).astype('float32')
424-
dy_pre = self.predict_dygraph(image)
425-
st_pre = self.predict_static(image)
426-
dy_jit_pre = self.predict_dygraph_jit(image)
427-
predictor_pre = self.predict_analysis_inference(image)
426+
dy_pre = self.resnet_helper.predict_dygraph(image)
427+
st_pre = self.resnet_helper.predict_static(image)
428+
dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
429+
predictor_pre = self.resnet_helper.predict_analysis_inference(image)
428430
np.testing.assert_allclose(
429431
dy_pre,
430432
st_pre,
@@ -455,7 +457,7 @@ def test_resnet(self):
455457
err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
456458
)
457459
# TODO(@xiongkun): open after save / load supported in pir.
458-
if not paddle.base.framework.use_pir_api():
460+
if not paddle.framework.use_pir_api():
459461
self.verify_predict()
460462

461463
@test_default_and_pir
@@ -474,12 +476,12 @@ def test_resnet_composite(self):
474476

475477
@test_default_and_pir
476478
def test_in_static_mode_mkldnn(self):
477-
paddle.base.set_flags({'FLAGS_use_mkldnn': True})
479+
paddle.set_flags({'FLAGS_use_mkldnn': True})
478480
try:
479481
if paddle.base.core.is_compiled_with_mkldnn():
480482
self.train(to_static=True)
481483
finally:
482-
paddle.base.set_flags({'FLAGS_use_mkldnn': False})
484+
paddle.set_flags({'FLAGS_use_mkldnn': False})
483485

484486

485487
if __name__ == '__main__':

test/dygraph_to_static/test_seq2seq.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
import numpy as np
2121
from dygraph_to_static_utils import (
2222
Dy2StTestBase,
23-
ToStaticMode,
24-
set_to_static_mode,
2523
test_legacy_only,
24+
test_sot_mgs0_only,
2625
)
2726
from seq2seq_dygraph_model import AttentionModel, BaseModel
2827
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter
@@ -239,13 +238,13 @@ def _test_predict(self, attn_model=False):
239238
msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}",
240239
)
241240

242-
@set_to_static_mode(ToStaticMode.SOT)
241+
@test_sot_mgs0_only
243242
@test_legacy_only
244243
def test_base_model(self):
245244
self._test_train(attn_model=False)
246245
self._test_predict(attn_model=False)
247246

248-
@set_to_static_mode(ToStaticMode.SOT)
247+
@test_sot_mgs0_only
249248
@test_legacy_only
250249
def test_attn_model(self):
251250
self._test_train(attn_model=True)

test/dygraph_to_static/test_train_step.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from functools import partial
1818

1919
import numpy as np
20-
from dygraph_to_static_utils import Dy2StTestBase
20+
from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_pt_only
2121

2222
import paddle
2323

@@ -77,6 +77,8 @@ def get_train_step_losses(self, func, steps):
7777
losses.append(loss)
7878
return losses
7979

80+
@test_ast_only
81+
@test_pt_only
8082
def test_train_step(self):
8183
reset_seed()
8284
dygraph_losses = self.get_train_step_losses(

0 commit comments

Comments
 (0)