
Commit 34ef38c

Fix optimizer op infershape failed in dygraph multi-cards mode (#21374) (#22112)
* add param & grad shape check for sgd op
* add _reshape_inplace interface for dygraph parallel
* refine unittest based on paddle/models scripts, test=develop
* add unittest for parallel grad fuse, test=develop
1 parent eb6d339 commit 34ef38c

5 files changed: +123 −8 lines

paddle/fluid/operators/optimizers/sgd_op.cc

Lines changed: 9 additions & 2 deletions

@@ -40,8 +40,15 @@ class SGDOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                       "Learning rate should have 1 element");
     auto param_dim = ctx->GetInputDim("Param");
-    // TODO(qijun): check dimensions of Param and Grad at compile
-    // and runtime.
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Grad"),
+          platform::errors::InvalidArgument(
+              "SGD Operator's input Param and Grad dimensions do not match. "
+              "The Param shape is [%s], but the Grad shape is [%s].",
+              param_dim, ctx->GetInputDim("Grad")));
+    }
     ctx->SetOutputDim("ParamOut", param_dim);
   }
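As a minimal sketch of the condition the new infershape check enforces (plain numpy, with a hypothetical helper name check_sgd_shapes; the shapes are made up): when a coalesced gradient has not been reshaped back to its parameter's shape, the operator now reports the mismatch explicitly.

import numpy as np

def check_sgd_shapes(param, grad):
    # Mirrors the added C++ check: for a dense (LoD tensor) gradient,
    # Param and Grad must have identical dimensions.
    if param.shape != grad.shape:
        raise ValueError(
            "SGD Operator's input Param and Grad dimensions do not match. "
            "The Param shape is [%s], but the Grad shape is [%s]." %
            (list(param.shape), list(grad.shape)))

param = np.zeros([2, 25], dtype="float32")
grad = np.zeros([50], dtype="float32")   # e.g. a coalesced gradient left flat
check_sgd_shapes(param, grad)            # raises ValueError listing both shapes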

python/paddle/fluid/dygraph/parallel.py

Lines changed: 11 additions & 1 deletion

@@ -184,6 +184,15 @@ def _coalesce_tensors(self, var_groups):
                 [coalesced_grad, grad_vars, g_var_shapes])
         return coalesced_grads_and_grad_vars
 
+    def _reshape_inplace(self, x, shape):
+        x_shape = self._helper.create_variable_for_type_inference(dtype=x.dtype)
+        self._helper.append_op(
+            type="reshape2",
+            inputs={'X': x},
+            attrs={'shape': shape},
+            outputs={'Out': x,
+                     'XShape': x_shape})
+
     def _split_tensors(self, coalesced_grads_and_grad_vars):
         from ..layers import nn
         for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
@@ -195,7 +204,8 @@ def _split_tensors(self, coalesced_grads_and_grad_vars):
                 attrs={'sections': grad_var_len,
                        'axis': 0})
             for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
-                nn.reshape(x=g_var, shape=g_shape, inplace=True)
+                self._reshape_inplace(x=g_var, shape=g_shape)
+                assert g_var.shape == g_shape
 
     @no_grad
     def apply_collective_grads(self):
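The new helper appends a reshape2 op whose 'Out' is the input variable itself, so each split gradient slice is restored to its recorded shape in place (the assert after the call verifies this). A minimal usage sketch, along the lines of the new unit test at the end of this commit (fluid 1.x dygraph API; DemoLayer is just a placeholder class):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import DataParallel
from paddle.fluid.dygraph.base import to_variable

class DemoLayer(fluid.Layer):
    def forward(self, inputs):
        return [fluid.layers.reduce_sum(inputs)]

with fluid.dygraph.guard():
    layer = DataParallel(DemoLayer("demo"), core.ParallelStrategy())
    x = to_variable(np.random.random([2, 25]).astype("float32"))
    layer._reshape_inplace(x, [5, 10])   # appends reshape2 with Out bound to x
    assert x.shape == [5, 10]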

python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ def get_model(self):
         model = MNIST("mnist")
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=2, drop_last=True)
-        opt = fluid.optimizer.SGD(learning_rate=1e-3)
+        opt = fluid.optimizer.Adam(learning_rate=1e-3)
         return model, train_reader, opt
 
     def run_one_loop(self, model, opt, data):

python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py

Lines changed: 20 additions & 4 deletions

@@ -33,9 +33,26 @@
 import math
 from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
 
+batch_size = 64
 momentum_rate = 0.9
 l2_decay = 1.2e-4
 
+train_parameters = {
+    "input_size": [3, 224, 224],
+    "input_mean": [0.485, 0.456, 0.406],
+    "input_std": [0.229, 0.224, 0.225],
+    "learning_strategy": {
+        "name": "cosine_decay",
+        "batch_size": batch_size,
+        "epochs": [40, 80, 100],
+        "steps": [0.1, 0.01, 0.001, 0.0001]
+    },
+    "batch_size": batch_size,
+    "lr": 0.0125,
+    "total_images": 6149,
+    "num_epochs": 200
+}
+
 
 def optimizer_setting(params):
     ls = params["learning_strategy"]
@@ -300,11 +317,10 @@ def get_model(self):
         model = SeResNeXt("se-resnext")
         train_reader = paddle.batch(
             paddle.dataset.flowers.test(use_xmap=False),
-            batch_size=4,
+            batch_size=train_parameters["batch_size"],
             drop_last=True)
-
-        opt = fluid.optimizer.SGD(learning_rate=1e-3)
-        return model, train_reader, opt
+        optimizer = optimizer_setting(train_parameters)
+        return model, train_reader, optimizer
 
     def run_one_loop(self, model, opt, data):
         bs = len(data)
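optimizer_setting(train_parameters) is defined earlier in this script; only its first lines are visible in the hunk above. As a rough sketch of what a "cosine_decay" learning strategy combined with the momentum_rate and l2_decay constants typically builds (an assumption for illustration, not the exact code from the file):

import math
import paddle.fluid as fluid

momentum_rate = 0.9
l2_decay = 1.2e-4

def optimizer_setting(params):
    # Sketch only: Momentum optimizer with a cosine-decayed learning rate
    # derived from the "learning_strategy" block of train_parameters.
    ls = params["learning_strategy"]
    step = int(math.ceil(float(params["total_images"]) / ls["batch_size"]))
    return fluid.optimizer.Momentum(
        learning_rate=fluid.layers.cosine_decay(
            learning_rate=params["lr"],
            step_each_epoch=step,
            epochs=params["num_epochs"]),
        momentum=momentum_rate,
        regularization=fluid.regularizer.L2Decay(l2_decay))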
New unit test for parallel grad fuse

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import unittest
+import numpy as np
+from collections import OrderedDict
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid.dygraph.parallel import DataParallel
+from paddle.fluid.dygraph.base import to_variable
+
+
+class MyLayer(fluid.Layer):
+    def __init__(self, name_scope):
+        super(MyLayer, self).__init__(name_scope)
+
+    def forward(self, inputs):
+        x = fluid.layers.relu(inputs)
+        x = fluid.layers.elementwise_mul(x, x)
+        x = fluid.layers.reduce_sum(x)
+        return [x]
+
+
+class TestImperativeParallelCoalesceSplit(unittest.TestCase):
+    def test_coalesce_split(self):
+        with fluid.dygraph.guard():
+            test_layer = MyLayer("test_layer")
+            strategy = core.ParallelStrategy()
+            test_layer = DataParallel(test_layer, strategy)
+
+            # test variables prepare
+            vars = []
+            vars.append(to_variable(np.random.random([2, 3]).astype("float32")))
+            vars.append(to_variable(np.random.random([4, 9]).astype("float32")))
+            vars.append(
+                to_variable(np.random.random([10, 1]).astype("float32")))
+            var_groups = OrderedDict()
+            var_groups.setdefault(0, vars)
+
+            # record shapes
+            orig_var_shapes = []
+            for var in vars:
+                orig_var_shapes.append(var.shape)
+
+            # execute interface
+            coalesced_vars = test_layer._coalesce_tensors(var_groups)
+            test_layer._split_tensors(coalesced_vars)
+
+            # compare
+            for orig_var_shape, var in zip(orig_var_shapes, vars):
+                self.assertEqual(orig_var_shape, var.shape)
+
+    def test_reshape_inplace(self):
+        with fluid.dygraph.guard():
+            test_layer = MyLayer("test_layer")
+            strategy = core.ParallelStrategy()
+            test_layer = DataParallel(test_layer, strategy)
+
+            ori_shape = [2, 25]
+            new_shape = [5, 10]
+            x_data = np.random.random(ori_shape).astype("float32")
+            x = to_variable(x_data)
+            test_layer._reshape_inplace(x, new_shape)
+            self.assertEqual(x.shape, new_shape)
+
+
+if __name__ == '__main__':
+    unittest.main()
