
Commit da95178

[2.0API] fix weight_norm support negative dim and unittest in convert_syncbn (#27108) (#27157)
* fix 2.0api, test=develop
* fix, test=develop

Parent: 2118868

4 files changed: +27 additions, -8 deletions

python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py

Lines changed: 10 additions & 0 deletions

@@ -121,6 +121,9 @@ def test_check_output(self):
         before_weight = linear.weight.numpy()
         if self.dim == None:
             self.dim = -1
+
+        if self.dim != -1:
+            self.dim = (self.dim + len(before_weight)) % len(before_weight)
         wn = weight_norm(linear, dim=self.dim)
         outputs = []
         for name, data in self.data.items():
@@ -158,6 +161,13 @@ def init_test_case(self):
         self.dim = 3
 
 
+class TestDygraphWeightNormCase4(TestDygraphWeightNorm):
+    def init_test_case(self):
+        self.batch_size = 3
+        self.data_desc = (['x', [2, 3, 3]], )
+        self.dim = -3
+
+
 class TestDygraphRemoveWeightNorm(unittest.TestCase):
     def setUp(self):
         self.init_test_case()
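
For context, a minimal sketch of the behavior the new negative-dim case exercises, assuming the 2.0-era paddle.nn API used throughout this commit (the layer, shapes, and print here are illustrative, not part of the test):

    import paddle
    from paddle.nn.utils import remove_weight_norm, weight_norm

    # The weight of Linear(3, 5) has shape [3, 5], so dim=-2 addresses axis 0.
    linear = paddle.nn.Linear(3, 5)
    linear = weight_norm(linear, dim=-2)   # equivalent to dim=0 after this fix
    print(linear.weight_g.shape, linear.weight_v.shape)
    remove_weight_norm(linear)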

python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py

Lines changed: 5 additions & 4 deletions

@@ -227,14 +227,15 @@ def test_convert(self):
             return
 
         with program_guard(Program(), Program()):
+            compare_model = paddle.nn.Sequential(
+                paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
             model = paddle.nn.Sequential(
                 paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
-            sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-            for idx, sublayer in enumerate(model.sublayers()):
+            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            for idx, sublayer in enumerate(compare_model.sublayers()):
                 if isinstance(sublayer, paddle.nn.BatchNorm2d):
                     self.assertEqual(
-                        isinstance(sync_model[idx], paddle.nn.SyncBatchNorm),
-                        True)
+                        isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
 
 
 if __name__ == '__main__':

python/paddle/nn/layer/norm.py

Lines changed: 4 additions & 4 deletions

@@ -1130,10 +1130,10 @@ def convert_sync_batchnorm(cls, layer):
         """
         layer_output = layer
         if isinstance(layer, _BatchNormBase):
-            layer_output = SyncBatchNorm(layer._num_features, layer._epsilon,
-                                         layer._momentum, layer._weight_attr,
-                                         layer._bias_attr, layer._data_format,
-                                         layer._name)
+            layer_output = SyncBatchNorm(
+                layer._num_features, layer._momentum, layer._epsilon,
+                layer._weight_attr, layer._bias_attr, layer._data_format,
+                layer._track_running_stats, layer._name)
 
             if layer._weight_attr != False and layer._bias_attr != False:
                 with no_grad():
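
Two things change in this call, judging from the diff: momentum and epsilon are now passed in the order SyncBatchNorm's constructor expects (the old call bound layer._epsilon to the momentum slot), and layer._track_running_stats is forwarded instead of being dropped. A hedged usage sketch, mirroring the updated unit test above (2.0-era Conv2d/BatchNorm2d names; on a CUDA build, as the test's guard suggests):

    import paddle

    model = paddle.nn.Sequential(
        paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Each BatchNorm2d sublayer has been swapped for a SyncBatchNorm.
    assert isinstance(model[1], paddle.nn.SyncBatchNorm)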

python/paddle/nn/utils/weight_norm_hook.py

Lines changed: 8 additions & 0 deletions

@@ -112,6 +112,14 @@ def apply(layer, name, dim):
         if dim is None:
             dim = -1
 
+        # support negative dim: (dim = -1) == (dim = None)
+        weight_dim = len(layer._parameters[name].shape)
+        assert (
+            dim < weight_dim and dim >= -1 * weight_dim
+        ), "dim must be set between [-R, R), R means the dimension of weight."
+        if dim != -1:
+            dim = (dim + weight_dim) % weight_dim
+
         fn = WeightNorm(name, dim)
 
         w = getattr(layer, name)
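
The hook now maps any in-range negative dim onto its non-negative equivalent, keeping -1 as the sentinel for dim=None (normalize over the flattened weight). A standalone sketch of just the index math; normalize_dim is a hypothetical helper, not part of paddle:

    def normalize_dim(dim, weight_dim):
        # Map a possibly negative dim onto [0, weight_dim); -1 is kept
        # as-is because it doubles as the sentinel for dim=None.
        if dim is None:
            dim = -1
        assert -weight_dim <= dim < weight_dim, \
            "dim must be set between [-R, R), R means the dimension of weight."
        if dim != -1:
            dim = (dim + weight_dim) % weight_dim
        return dim

    assert normalize_dim(-3, 4) == 1    # e.g. a 4-D conv weight
    assert normalize_dim(2, 4) == 2     # non-negative dims pass through
    assert normalize_dim(-1, 4) == -1   # sentinel preserved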
