@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
 void MultiGradientMachine::eval(Evaluator* evaluator) const {
   for (auto& thread : threads_) {
     SetDevice device(thread->getDeviceId());
-    thread->getGradientMachine()->eval(evaluator);
+    if (thread->hasInputData()) {
+      thread->getGradientMachine()->eval(evaluator);
+    }
   }
 }
 
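Note on the guard above: when a mini-batch carries fewer sequences than there are trainer threads, some threads receive no input and never run a forward pass, so evaluating their gradient machines would read empty or stale output. A minimal sketch of how hasInputData() can be backed by the per-thread batch size; the accessor body below is an assumption for illustration, not the committed header change:

// Hypothetical sketch: the thread remembers how many sequences
// copyInArgs() actually handed it, and hasInputData() reports whether
// that count is non-zero.
class TrainerThread {
public:
  bool hasInputData() const { return batchSize_ > 0; }

private:
  int batchSize_ = 0;  // updated from the return value of copyInArgs()
};
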
@@ -356,14 +358,20 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
     REGISTER_TIMER("waitOutArgs");
     thread->waitOutArgsReady();
   }
-  outArgs_.resize(threads_[0]->getOutArgs().size());
+  // outArgs_.size() only needs to be calculated once.
+  static int size = threads_[threads_.size() - 1]->getOutArgs().size();
+  outArgs_.resize(size);
 
   REGISTER_TIMER("copyOutArgs");
   for (size_t i = 0; i < outArgs_.size(); ++i) {
     std::vector<Argument> args;
     args.reserve(threads_.size());
     for (auto& thread : threads_) {
-      args.push_back(thread->getOutArgs()[i]);
+      // If the thread input is empty, then the output is empty.
+      auto tmp = thread->getOutArgs();
+      if (tmp.size() > 0) {
+        args.push_back(tmp[i]);
+      }
     }
     outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
   }
@@ -534,7 +542,7 @@ void TrainerThread::prefetch() {
 void TrainerThread::forward() {
   if (!inArgsCopied_) {
     REGISTER_TIMER("copyInArgs");
-    copyInArgs();
+    batchSize_ = copyInArgs();
   } else {
     inArgsCopied_ = false;
   }
@@ -564,7 +572,12 @@ void TrainerThread::forward() {
 
   {
     REGISTER_TIMER("thread_forward");
-    gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
+    if (batchSize_ > 0) {
+      gradientMachine_->forward(
+          inArgs_, &outArgs_, multiMachine_->getPassType());
+    } else {
+      outArgs_.clear();
+    }
   }
   outArgsReadySem_.post();
 }
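Note on the forward() change: a thread whose copyInArgs() returned 0 skips the forward pass and clears outArgs_ instead, but it still posts outArgsReadySem_, so MultiGradientMachine::getOutArgs() can keep waiting on every thread without special-casing the empty ones; the empty outArgs_ is what the tmp.size() > 0 check in the merge loop above detects.
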
@@ -574,7 +587,13 @@ void TrainerThread::backward() {
   if (multiMachine_->isPassGrad()) {
     copyOutputGrad();
   }
-  gradientMachine_->backward(backwardCallback_);
+  if (batchSize_ > 0) {
+    gradientMachine_->backward(backwardCallback_);
+  } else {
+    for (size_t i = parameters_.size(); i > 0; i--) {
+      backwardCallback(parameters_[i - 1].get());
+    }
+  }
   if (multiMachine_->hasNonstaticCpuParamters()) {
     mergeCpuGradients();
   }
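Note on the backward() change: a thread with no data cannot run a real backward pass, but the per-parameter gradient synchronization with the other threads should not be left waiting on it, so the else branch still fires backwardCallback once for every parameter, walking the list in reverse, presumably to match the back-to-front order in which a real backward pass would report gradients.
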
@@ -732,7 +751,7 @@ void TrainerThread::notifyValueReady(int paramId) {
   notifyValueDispatch(paramId);
 }
 
-void TrainerThread::copyInArgs() {
+int TrainerThread::copyInArgs() {
   const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
   int numThreads = multiMachine_->getAllThreads().size();
   int32_t numSequences = fullInArgs[0].getNumSequences();
@@ -748,7 +767,7 @@ void TrainerThread::copyInArgs() {
   }
 
   if (copySize == 0) {
-    return;
+    return 0;
   }
 
   for (size_t i = 0; i < fullInArgs.size(); i++) {
@@ -758,6 +777,7 @@ void TrainerThread::copyInArgs() {
         copySize,
         FLAGS_parallel_nn ? false : multiMachine_->useGpu());
   }
+  return copySize;
 }
 
 void TrainerThread::mergeCpuGradients() {
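Note on copyInArgs(): the function now reports how many sequences this thread actually received. The slice computation itself lies outside the hunks shown here, but a hedged sketch of an even split (the formula below is an assumption for illustration, not the elided code) shows why copySize can legitimately be 0 and why the callers above branch on batchSize_ > 0:

#include <cstdint>

// Hypothetical even split of a batch across trainer threads; any such
// split leaves some threads with an empty slice whenever there are
// fewer sequences than threads.
int32_t sliceSize(int32_t numSequences, int numThreads, int threadId) {
  int64_t begin = static_cast<int64_t>(numSequences) * threadId / numThreads;
  int64_t end = static_cast<int64_t>(numSequences) * (threadId + 1) / numThreads;
  return static_cast<int32_t>(end - begin);  // this thread's copySize
}

For example, with numSequences = 3 and numThreads = 8, this split hands one sequence each to three threads and copySize == 0 to the other five, which is exactly the case the batchSize_ > 0 checks in forward() and backward() guard against.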