@@ -478,81 +478,82 @@ def test_with_place(place, data_layout, shape):
478
478
]
479
479
ground_truth = {name : var_dict [name ] for name in var_names }
480
480
481
- program = base .Program ()
482
- with base .program_guard (program ):
483
- block = program .global_block ()
484
- for name in ground_truth :
485
- block .create_var (
486
- name = name , dtype = "float32" , shape = ground_truth [name ].shape
481
+ with paddle .pir_utils .OldIrGuard ():
482
+ program = base .Program ()
483
+ with base .program_guard (program ):
484
+ block = program .global_block ()
485
+ for name in ground_truth :
486
+ block .create_var (
487
+ name = name , dtype = "float32" , shape = ground_truth [name ].shape
488
+ )
489
+ inputs = {
490
+ "X" : block .var ("x" ),
491
+ "Scale" : block .var ("scale" ),
492
+ "Bias" : block .var ("bias" ),
493
+ "Mean" : block .var ("mean" ),
494
+ "Variance" : block .var ("variance" ),
495
+ }
496
+ attrs = {
497
+ "epsilon" : epsilon ,
498
+ "is_test" : False ,
499
+ "data_layout" : data_layout ,
500
+ "use_mkldnn" : False ,
501
+ "fuse_with_relu" : self .fuse_with_relu ,
502
+ "use_global_stats" : self .use_global_stats ,
503
+ }
504
+ if self .use_momentum_variable :
505
+ inputs ["MomentumTensor" ] = block .var ("momentum_var" )
506
+ else :
507
+ attrs ["momentum" ] = momentum
508
+
509
+ outputs = {
510
+ "Y" : block .var ("y" ),
511
+ "MeanOut" : block .var ("mean" ), # share memory
512
+ "VarianceOut" : block .var ("variance" ), # share memory
513
+ "SavedMean" : block .var ("saved_mean" ),
514
+ "SavedVariance" : block .var ("saved_variance" ),
515
+ }
516
+ block .create_var (name = "reserve_space" , dtype = "float32" )
517
+ outputs ["ReserveSpace" ] = block .var ("reserve_space" )
518
+ bn_op = block .append_op (
519
+ type = "batch_norm" , inputs = inputs , outputs = outputs , attrs = attrs
520
+ )
521
+ block .create_var (name = "y@GRAD" , dtype = "float32" , shape = y .shape )
522
+
523
+ # generate backward op_desc
524
+ grad_op_desc_list , op_grad_to_var = core .get_grad_op_desc (
525
+ bn_op .desc , self .no_grad_set , []
526
+ )
527
+ grad_op_desc = grad_op_desc_list [0 ]
528
+ new_op_desc = block .desc .append_op ()
529
+ new_op_desc .copy_from (grad_op_desc )
530
+ for var_name in grad_op_desc .output_arg_names ():
531
+ block .desc .var (var_name .encode ("ascii" ))
532
+ grad_op_desc .infer_var_type (block .desc )
533
+ grad_op_desc .infer_shape (block .desc )
534
+ for arg in grad_op_desc .output_arg_names ():
535
+ grad_var = block .desc .find_var (arg .encode ("ascii" ))
536
+ grad_var .set_dtype (core .VarDesc .VarType .FP32 )
537
+
538
+ program ._sync_with_cpp ()
539
+
540
+ exe = base .Executor (place )
541
+ out = exe .run (
542
+ program ,
543
+ feed = {
544
+ name : var_dict [name ]
545
+ for name in [
546
+ "x" ,
547
+ "scale" ,
548
+ "bias" ,
549
+ "mean" ,
550
+ "variance" ,
551
+ "y@GRAD" ,
552
+ "momentum_var" ,
553
+ ]
554
+ },
555
+ fetch_list = self .fetch_list ,
487
556
)
488
- inputs = {
489
- "X" : block .var ("x" ),
490
- "Scale" : block .var ("scale" ),
491
- "Bias" : block .var ("bias" ),
492
- "Mean" : block .var ("mean" ),
493
- "Variance" : block .var ("variance" ),
494
- }
495
- attrs = {
496
- "epsilon" : epsilon ,
497
- "is_test" : False ,
498
- "data_layout" : data_layout ,
499
- "use_mkldnn" : False ,
500
- "fuse_with_relu" : self .fuse_with_relu ,
501
- "use_global_stats" : self .use_global_stats ,
502
- }
503
- if self .use_momentum_variable :
504
- inputs ["MomentumTensor" ] = block .var ("momentum_var" )
505
- else :
506
- attrs ["momentum" ] = momentum
507
-
508
- outputs = {
509
- "Y" : block .var ("y" ),
510
- "MeanOut" : block .var ("mean" ), # share memory
511
- "VarianceOut" : block .var ("variance" ), # share memory
512
- "SavedMean" : block .var ("saved_mean" ),
513
- "SavedVariance" : block .var ("saved_variance" ),
514
- }
515
- block .create_var (name = "reserve_space" , dtype = "float32" )
516
- outputs ["ReserveSpace" ] = block .var ("reserve_space" )
517
- bn_op = block .append_op (
518
- type = "batch_norm" , inputs = inputs , outputs = outputs , attrs = attrs
519
- )
520
- block .create_var (name = "y@GRAD" , dtype = "float32" , shape = y .shape )
521
-
522
- # generate backward op_desc
523
- grad_op_desc_list , op_grad_to_var = core .get_grad_op_desc (
524
- bn_op .desc , self .no_grad_set , []
525
- )
526
- grad_op_desc = grad_op_desc_list [0 ]
527
- new_op_desc = block .desc .append_op ()
528
- new_op_desc .copy_from (grad_op_desc )
529
- for var_name in grad_op_desc .output_arg_names ():
530
- block .desc .var (var_name .encode ("ascii" ))
531
- grad_op_desc .infer_var_type (block .desc )
532
- grad_op_desc .infer_shape (block .desc )
533
- for arg in grad_op_desc .output_arg_names ():
534
- grad_var = block .desc .find_var (arg .encode ("ascii" ))
535
- grad_var .set_dtype (core .VarDesc .VarType .FP32 )
536
-
537
- program ._sync_with_cpp ()
538
-
539
- exe = base .Executor (place )
540
- out = exe .run (
541
- program ,
542
- feed = {
543
- name : var_dict [name ]
544
- for name in [
545
- "x" ,
546
- "scale" ,
547
- "bias" ,
548
- "mean" ,
549
- "variance" ,
550
- "y@GRAD" ,
551
- "momentum_var" ,
552
- ]
553
- },
554
- fetch_list = self .fetch_list ,
555
- )
556
557
557
558
for id , name in enumerate (self .fetch_list ):
558
559
if name == "variance" :
0 commit comments