
Commit 07cf302

Author: Yang Yang
Commit message: first commit
Parent: 7905e36

2 files changed, 11 insertions(+), 6 deletions(-)

paddle/operators/parallel_do_op.cc

Lines changed: 9 additions & 1 deletion
@@ -64,6 +64,12 @@ static void SplitTensorAndMoveTensorToScopes(
   }
 }
 
+void WaitOnPlace(const platform::Place place) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+  dev_ctx.Wait();
+}
+
 void WaitOnPlaces(const std::vector<platform::Place> places) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
 
@@ -214,6 +220,7 @@ class ParallelDoGradOp : public framework::OperatorBase {
         auto &tensor_to_merge = sub_scopes[i]->FindVar(s)->Get<LoDTensor>();
         if (!(places[i] == places[0])) {
           framework::Copy(tensor_to_merge, places[0], tmp);
+          WaitOnPlace(places[0]);
         } else {
           tmp->ShareDataWith(tensor_to_merge);
         }
@@ -222,12 +229,13 @@ class ParallelDoGradOp : public framework::OperatorBase {
             "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
             framework::AttributeMap{});
         sum_op->Run(*sub_scopes[0], places[0]);
-        WaitOnPlaces(places);
+        WaitOnPlace(places[0]);
       }
 
       VLOG(3) << result;
       framework::Copy(result, place, scope.FindVar(s)->GetMutable<LoDTensor>());
     }
+    WaitOnPlaces(places);
   }
 };
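
Note on the change: after this patch the two synchronization helpers in parallel_do_op.cc read roughly as sketched below. WaitOnPlace is taken verbatim from the hunk above; the loop body of WaitOnPlaces falls outside the visible diff context and is assumed here to mirror WaitOnPlace for every place in the list.

// Sketch reconstructed from the hunk above (inside parallel_do_op.cc, which
// already has the needed includes); the WaitOnPlaces loop is an assumption,
// not shown in the visible context.
void WaitOnPlace(const platform::Place place) {
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);
  dev_ctx.Wait();  // block until pending work on this single place finishes
}

void WaitOnPlaces(const std::vector<platform::Place> places) {
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();

  for (auto &place : places) {  // assumed: wait on every place in turn
    auto &dev_ctx = *pool.Get(place);
    dev_ctx.Wait();
  }
}

With the per-place helper available, the gradient-merge loop in ParallelDoGradOp only waits on places[0] (where the copy and the sum op run) inside each iteration, and a single WaitOnPlaces(places) after the loop synchronizes all devices once.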

python/paddle/v2/fluid/tests/test_parallel_op.py

Lines changed: 2 additions & 5 deletions
@@ -15,9 +15,6 @@
 
 import paddle.v2.fluid as fluid
 import numpy
-import sys
-# TODO(dzhwinter): get places op check need to be enhanced.
-sys.exit(0)
 
 
 class BaseParallelForTest(unittest.TestCase):
@@ -165,13 +162,13 @@ def test_simple_fc(self):
             feed={
                 'img': numpy.random.random(size=(51, 784)).astype('float32')
             },
-            fetch='fc1.w@GRAD')
+            fetch=['fc1.w@GRAD'])
 
     def test_fc_with_tiny_data(self):
         self.run_test(
             callback=ParallelOpTest.__network__,
             feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
-            fetch='fc1.w@GRAD')
+            fetch=['fc1.w@GRAD'])
 
 
 if __name__ == '__main__':
