Skip to content

Commit b1c22b6

Browse files
committed
Fix MultiGradientMachine error
1 parent ca62c10 commit b1c22b6

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

paddle/gserver/gradientmachines/MultiGradientMachine.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
346346
void MultiGradientMachine::eval(Evaluator* evaluator) const {
347347
for (auto& thread : threads_) {
348348
SetDevice device(thread->getDeviceId());
349-
thread->getGradientMachine()->eval(evaluator);
349+
if (thread->hasInputData()) {
350+
thread->getGradientMachine()->eval(evaluator);
351+
}
350352
}
351353
}
352354

@@ -356,14 +358,20 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
356358
REGISTER_TIMER("waitOutArgs");
357359
thread->waitOutArgsReady();
358360
}
359-
outArgs_.resize(threads_[0]->getOutArgs().size());
361+
// outArgs_.size() only needs to be calculated once.
362+
static int size = threads_[threads_.size() - 1]->getOutArgs().size();
363+
outArgs_.resize(size);
360364

361365
REGISTER_TIMER("copyOutArgs");
362366
for (size_t i = 0; i < outArgs_.size(); ++i) {
363367
std::vector<Argument> args;
364368
args.reserve(threads_.size());
365369
for (auto& thread : threads_) {
366-
args.push_back(thread->getOutArgs()[i]);
370+
// If the thread input is empty, then the output is empty.
371+
auto tmp = thread->getOutArgs();
372+
if (tmp.size() > 0) {
373+
args.push_back(tmp[i]);
374+
}
367375
}
368376
outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
369377
}
@@ -534,7 +542,7 @@ void TrainerThread::prefetch() {
534542
void TrainerThread::forward() {
535543
if (!inArgsCopied_) {
536544
REGISTER_TIMER("copyInArgs");
537-
copyInArgs();
545+
batchSize_ = copyInArgs();
538546
} else {
539547
inArgsCopied_ = false;
540548
}
@@ -564,7 +572,12 @@ void TrainerThread::forward() {
564572

565573
{
566574
REGISTER_TIMER("thread_forward");
567-
gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
575+
if (batchSize_ > 0) {
576+
gradientMachine_->forward(
577+
inArgs_, &outArgs_, multiMachine_->getPassType());
578+
} else {
579+
outArgs_.clear();
580+
}
568581
}
569582
outArgsReadySem_.post();
570583
}
@@ -574,7 +587,13 @@ void TrainerThread::backward() {
574587
if (multiMachine_->isPassGrad()) {
575588
copyOutputGrad();
576589
}
577-
gradientMachine_->backward(backwardCallback_);
590+
if (batchSize_ > 0) {
591+
gradientMachine_->backward(backwardCallback_);
592+
} else {
593+
for (size_t i = parameters_.size(); i > 0; i--) {
594+
backwardCallback(parameters_[i - 1].get());
595+
}
596+
}
578597
if (multiMachine_->hasNonstaticCpuParamters()) {
579598
mergeCpuGradients();
580599
}
@@ -732,7 +751,7 @@ void TrainerThread::notifyValueReady(int paramId) {
732751
notifyValueDispatch(paramId);
733752
}
734753

735-
void TrainerThread::copyInArgs() {
754+
int TrainerThread::copyInArgs() {
736755
const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
737756
int numThreads = multiMachine_->getAllThreads().size();
738757
int32_t numSequences = fullInArgs[0].getNumSequences();
@@ -748,7 +767,7 @@ void TrainerThread::copyInArgs() {
748767
}
749768

750769
if (copySize == 0) {
751-
return;
770+
return 0;
752771
}
753772

754773
for (size_t i = 0; i < fullInArgs.size(); i++) {
@@ -758,6 +777,7 @@ void TrainerThread::copyInArgs() {
758777
copySize,
759778
FLAGS_parallel_nn ? false : multiMachine_->useGpu());
760779
}
780+
return copySize;
761781
}
762782

763783
void TrainerThread::mergeCpuGradients() {

paddle/gserver/gradientmachines/MultiGradientMachine.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ class TrainerThread {
387387
/// copy the output gradient from the main GradientMachine.
388388
void copyOutputGrad();
389389

390+
/// Whether the thread has input data.
391+
bool hasInputData() { return batchSize_ != 0; }
392+
390393
protected:
391394
void mergeCpuGradients();
392395

@@ -407,7 +410,7 @@ class TrainerThread {
407410
void copyGradToBufferThread();
408411
void gradCollectThread();
409412

410-
void copyInArgs();
413+
int copyInArgs();
411414
void forward();
412415
void backward();
413416
void backwardCallback(Parameter* para);
@@ -467,6 +470,8 @@ class TrainerThread {
467470

468471
/// indicate whether inArgs is copied before forward()
469472
bool inArgsCopied_;
473+
474+
int batchSize_;
470475
};
471476

472477
} // namespace paddle

0 commit comments

Comments
 (0)