
Commit 101a9a4

Merge pull request #1566 from hedaoyuan/multi-gradient-machine-error
Fix MultiGradientMachine error
2 parents 1fa7302 + df8a5af commit 101a9a4

2 files changed: +32 −9 lines

paddle/gserver/gradientmachines/MultiGradientMachine.cpp

Lines changed: 27 additions & 8 deletions
@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
 void MultiGradientMachine::eval(Evaluator* evaluator) const {
   for (auto& thread : threads_) {
     SetDevice device(thread->getDeviceId());
-    thread->getGradientMachine()->eval(evaluator);
+    if (thread->hasInputData()) {
+      thread->getGradientMachine()->eval(evaluator);
+    }
   }
 }

@@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
     REGISTER_TIMER("waitOutArgs");
     thread->waitOutArgsReady();
   }
-  outArgs_.resize(threads_[0]->getOutArgs().size());
+
+  outArgs_.resize(threads_[threads_.size() - 1]->getOutArgs().size());

   REGISTER_TIMER("copyOutArgs");
   for (size_t i = 0; i < outArgs_.size(); ++i) {
     std::vector<Argument> args;
     args.reserve(threads_.size());
     for (auto& thread : threads_) {
-      args.push_back(thread->getOutArgs()[i]);
+      // If the thread input is empty, then the output is empty.
+      auto tmp = thread->getOutArgs();
+      if (tmp.size() > 0) {
+        args.push_back(tmp[i]);
+      }
     }
     outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
   }
@@ -534,7 +541,7 @@ void TrainerThread::prefetch() {
 void TrainerThread::forward() {
   if (!inArgsCopied_) {
     REGISTER_TIMER("copyInArgs");
-    copyInArgs();
+    batchSize_ = copyInArgs();
   } else {
     inArgsCopied_ = false;
   }
@@ -564,7 +571,12 @@ void TrainerThread::forward() {

   {
     REGISTER_TIMER("thread_forward");
-    gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
+    if (batchSize_ > 0) {
+      gradientMachine_->forward(
+          inArgs_, &outArgs_, multiMachine_->getPassType());
+    } else {
+      outArgs_.clear();
+    }
   }
   outArgsReadySem_.post();
 }
@@ -574,7 +586,13 @@ void TrainerThread::backward() {
   if (multiMachine_->isPassGrad()) {
     copyOutputGrad();
   }
-  gradientMachine_->backward(backwardCallback_);
+  if (batchSize_ > 0) {
+    gradientMachine_->backward(backwardCallback_);
+  } else {
+    for (size_t i = parameters_.size(); i > 0; i--) {
+      backwardCallback(parameters_[i - 1].get());
+    }
+  }
   if (multiMachine_->hasNonstaticCpuParamters()) {
     mergeCpuGradients();
   }
@@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) {
   notifyValueDispatch(paramId);
 }

-void TrainerThread::copyInArgs() {
+int TrainerThread::copyInArgs() {
   const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
   int numThreads = multiMachine_->getAllThreads().size();
   int32_t numSequences = fullInArgs[0].getNumSequences();
@@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() {
   }

   if (copySize == 0) {
-    return;
+    return 0;
   }

   for (size_t i = 0; i < fullInArgs.size(); i++) {
@@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() {
         copySize,
         FLAGS_parallel_nn ? false : multiMachine_->useGpu());
   }
+  return copySize;
 }

 void TrainerThread::mergeCpuGradients() {
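
For context on the guards introduced above: copyInArgs() distributes the mini-batch's sequences across the trainer threads, so when a batch has fewer sequences than there are threads, some threads end up with copySize == 0 and have nothing to forward, backward, or evaluate. The standalone sketch below only illustrates that situation; splitBatch and the even partitioning are illustrative assumptions, since the actual per-thread split inside copyInArgs() is not part of this diff.

// Standalone sketch (hypothetical helper, not PaddlePaddle code): evenly
// partition `numSequences` across `numThreads` and show that trailing threads
// can receive a copy size of 0, which is the case the batchSize_ /
// hasInputData() guards in this commit handle.
#include <cstdio>
#include <vector>

std::vector<int> splitBatch(int numSequences, int numThreads) {
  std::vector<int> copySizes(numThreads, 0);
  int base = numSequences / numThreads;
  int rest = numSequences % numThreads;
  for (int t = 0; t < numThreads; ++t) {
    copySizes[t] = base + (t < rest ? 1 : 0);  // some entries may be 0
  }
  return copySizes;
}

int main() {
  // 2 sequences spread over 4 trainer threads: threads 2 and 3 get nothing,
  // so their forward()/backward()/eval() must be skipped or special-cased.
  for (int copySize : splitBatch(/*numSequences=*/2, /*numThreads=*/4)) {
    std::printf("copySize=%d hasInputData=%s\n", copySize,
                copySize != 0 ? "true" : "false");
  }
  return 0;
}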

paddle/gserver/gradientmachines/MultiGradientMachine.h

Lines changed: 5 additions & 1 deletion
@@ -387,6 +387,9 @@ class TrainerThread {
   /// copy the output gradient from the main GradientMachine.
   void copyOutputGrad();

+  /// Whether the thread has input data.
+  bool hasInputData() { return batchSize_ != 0; }
+
 protected:
   void mergeCpuGradients();

@@ -407,7 +410,7 @@ class TrainerThread {
   void copyGradToBufferThread();
   void gradCollectThread();

-  void copyInArgs();
+  int copyInArgs();
   void forward();
   void backward();
   void backwardCallback(Parameter* para);
@@ -467,6 +470,7 @@ class TrainerThread {

   /// indicate whether inArgs is copied before forward()
   bool inArgsCopied_;
+  int batchSize_;
 };

 } // namespace paddle
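
One detail of the TrainerThread::backward() change is worth a note: when batchSize_ is zero the thread cannot run a real backward pass, yet it still walks the parameters in reverse and fires backwardCallback for each one, presumably so the per-parameter gradient handling that the other threads participate in still hears from this thread about every parameter. The sketch below is only a minimal mock of that branch; the Parameter struct and the callback body are stand-ins, not the Paddle classes.

// Minimal illustration (hypothetical stand-in types) of the empty-batch
// fallback in TrainerThread::backward(): with no data, report every
// parameter to the per-parameter callback, in reverse order.
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <vector>

struct Parameter {  // stand-in for paddle::Parameter
  std::string name;
};

int main() {
  std::vector<std::shared_ptr<Parameter>> parameters = {
      std::make_shared<Parameter>(Parameter{"fc1.w"}),
      std::make_shared<Parameter>(Parameter{"fc2.w"}),
      std::make_shared<Parameter>(Parameter{"fc2.b"})};

  // Stand-in for TrainerThread::backwardCallback(Parameter*).
  std::function<void(Parameter*)> backwardCallback = [](Parameter* para) {
    std::printf("callback for %s (no gradient computed)\n", para->name.c_str());
  };

  int batchSize = 0;  // this thread received no sequences
  if (batchSize > 0) {
    // real work would go here: gradientMachine_->backward(backwardCallback_);
  } else {
    for (size_t i = parameters.size(); i > 0; i--) {
      backwardCallback(parameters[i - 1].get());
    }
  }
  return 0;
}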
