
Commit cb64200

Fit new paddle (PaddlePaddle#1044)

1 parent 380bce6 commit cb64200

This commit adapts the once-for-all (OFA) batch-norm layers to newer Paddle releases: in the new dygraph mode they call _C_ops.final_state_batch_norm, in legacy dygraph they call core.ops.batch_norm with its added positional argument, and on older Paddle versions a try/except falls back to the original core.ops.batch_norm call.
File tree

2 files changed: +85, -27 lines


paddleslim/nas/ofa/layers.py

Lines changed: 43 additions & 14 deletions
@@ -963,19 +963,48 @@ def forward(self, input):
             "use_mkldnn", False, "fuse_with_relu", False,
             "use_global_stats", self._use_global_stats,
             "trainable_statistics", trainable_statistics)
-
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim].set_value(mean)
-            self._variance[:feature_dim].set_value(variance)
-            mean_out[:feature_dim].set_value(mean_out_tmp)
-            variance_out[:feature_dim].set_value(variance_out_tmp)
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
 
         self.cur_config = {'prune_dim': feature_dim}
         return batch_norm_out[0]
@@ -1246,4 +1275,4 @@ def forward(self, input, expand_ratio=None, channel=None):
             weight=weight,
             padding_idx=self._padding_idx,
             sparse=self._sparse,
-            name=self._name)
+            name=self._name)
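
On both sides of the try/except, the new code is a three-way dispatch on the Paddle version. A minimal, self-contained sketch of that pattern follows; dispatch_batch_norm is a hypothetical helper written for this page, not a PaddleSlim function, and the commit's bare except: is narrowed to except Exception: here only for readability.

def dispatch_batch_norm(input, weight, bias, mean, variance, mean_out,
                        variance_out, *attrs):
    try:
        # Paddle >= 2.3: the final-state ops live in paddle._C_ops, and the
        # framework distinguishes the new dygraph mode from the legacy one.
        from paddle import _C_ops
        from paddle.fluid import core
        from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

        if in_dygraph_mode():
            # New dygraph mode: call the final-state kernel.
            return _C_ops.final_state_batch_norm(
                input, weight, bias, mean, variance, mean_out,
                variance_out, *attrs)
        if _in_legacy_dygraph():
            # Legacy dygraph on new Paddle: core.ops.batch_norm gained an
            # extra positional argument, passed as None in the commit.
            return core.ops.batch_norm(
                input, weight, bias, mean, variance, None, mean_out,
                variance_out, *attrs)
    except Exception:
        # Older Paddle: the imports or new ops above do not exist.
        pass
    # Fall back to the original core.ops.batch_norm signature.
    from paddle.fluid import core
    return core.ops.batch_norm(input, weight, bias, mean, variance,
                               mean_out, variance_out, *attrs)

Folding the dispatch into a helper like this would avoid repeating the pruned/full if/else in every branch, which is why the diff above is as large as it is.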

paddleslim/nas/ofa/layers_old.py

Lines changed: 42 additions & 13 deletions
@@ -903,19 +903,48 @@ def forward(self, input):
             "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
             "use_global_stats", self._use_global_stats,
             'trainable_statistics', self._trainable_statistics)
-
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim] = mean
-            self._variance[:feature_dim] = variance
-            mean_out[:feature_dim] = mean_out_tmp
-            variance_out[:feature_dim] = variance_out_tmp
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim] = mean
+                    self._variance[:feature_dim] = variance
+                    mean_out[:feature_dim] = mean_out_tmp
+                    variance_out[:feature_dim] = variance_out_tmp
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
 
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out[0], act=self._act)
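
The pruned branch above is why the feature_dim check exists: when an OFA sub-network runs with fewer channels than the full layer, batch norm returns statistics only for the active channels, and those are written back into the front slice of the full-size running buffers. A small NumPy stand-in of that write-back (all names and values here are illustrative, not PaddleSlim code):

import numpy as np

full_channels, feature_dim = 8, 5   # full layer vs. pruned sub-network
running_mean = np.zeros(full_channels, dtype=np.float32)
running_variance = np.ones(full_channels, dtype=np.float32)

# Batch statistics computed on the pruned sub-network (placeholder values).
mean = np.full(feature_dim, 0.2, dtype=np.float32)
variance = np.full(feature_dim, 0.9, dtype=np.float32)

if feature_dim != running_mean.shape[0]:
    # Mirrors `self._mean[:feature_dim] = mean` in layers_old.py;
    # layers.py does the same via `.set_value(...)` on Paddle tensors.
    running_mean[:feature_dim] = mean
    running_variance[:feature_dim] = variance

print(running_mean)   # first 5 channels updated, remaining 3 untouched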
