@@ -963,19 +963,48 @@ def forward(self, input):
                  "use_mkldnn", False, "fuse_with_relu", False,
                  "use_global_stats", self._use_global_stats,
                  "trainable_statistics", trainable_statistics)
-
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim].set_value(mean)
-            self._variance[:feature_dim].set_value(variance)
-            mean_out[:feature_dim].set_value(mean_out_tmp)
-            variance_out[:feature_dim].set_value(variance_out_tmp)
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
 
         self.cur_config = {'prune_dim': feature_dim}
         return batch_norm_out[0]
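
The hunk above wraps three generations of Paddle's dygraph batch_norm entry points in a `try`/`except`, so the same `forward()` runs on builds before and after the `_C_ops` migration. Below is a minimal sketch of that dispatch pattern, not PaddleSlim's actual code: `run_new`, `run_legacy`, and `run_old` are hypothetical callables standing in for the three batch_norm call sites, and only the `paddle.fluid.framework` mode probes (present around Paddle 2.3) are real APIs.

```python
# Sketch of the version dispatch used in the hunk above.
# run_new/run_legacy/run_old are hypothetical stand-ins for the three
# batch_norm call sites; they are not Paddle APIs.
def dispatch_batch_norm(run_new, run_legacy, run_old):
    try:
        # These probes only exist in Paddle builds around 2.3; on older
        # builds this import raises and we fall back to the oldest path.
        from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
        if in_dygraph_mode():
            return run_new()     # _C_ops.final_state_batch_norm path
        if _in_legacy_dygraph():
            return run_legacy()  # core.ops.batch_norm with the extra None input
    except Exception:
        pass                     # mirrors the hunk's bare except
    return run_old()             # core.ops.batch_norm, original signature
```

Note that the bare `except:` in the hunk catches any failure inside the new-API branches, not just a missing import, which is why the fallback re-issues the call with the original `core.ops.batch_norm` signature; a narrower `except ImportError:` would not cover that case.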
@@ -1246,4 +1275,4 @@ def forward(self, input, expand_ratio=None, channel=None):
             weight=weight,
             padding_idx=self._padding_idx,
             sparse=self._sparse,
-            name=self._name)
+            name=self._name)