@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
 void MultiGradientMachine::eval(Evaluator* evaluator) const {
   for (auto& thread : threads_) {
     SetDevice device(thread->getDeviceId());
-    thread->getGradientMachine()->eval(evaluator);
+    if (thread->hasInputData()) {
+      thread->getGradientMachine()->eval(evaluator);
+    }
   }
 }

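A thread can end up with no input data when the mini-batch holds fewer sequences than there are trainer threads (see the copyInArgs() change further down), so evaluation is now guarded per thread. Below is a minimal standalone sketch of the same "skip workers with an empty slice" pattern; the Worker type and its members are hypothetical illustrations, not PaddlePaddle's API.

// Standalone sketch (hypothetical Worker type, not PaddlePaddle's API):
// skip workers that received an empty slice of the mini-batch.
#include <iostream>
#include <vector>

struct Worker {
  std::vector<int> inputs;  // this worker's slice of the mini-batch
  bool hasInputData() const { return !inputs.empty(); }
  int localSum() const {
    int s = 0;
    for (int v : inputs) s += v;
    return s;
  }
};

int main() {
  // Four workers but only three data items: the last worker gets nothing.
  std::vector<Worker> workers(4);
  workers[0].inputs = {1};
  workers[1].inputs = {2};
  workers[2].inputs = {3};

  int total = 0;
  for (const auto& w : workers) {
    if (w.hasInputData()) {  // same guard as in the hunk above
      total += w.localSum();
    }
  }
  std::cout << "total = " << total << "\n";  // prints: total = 6
}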
@@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
     REGISTER_TIMER("waitOutArgs");
     thread->waitOutArgsReady();
   }
-  outArgs_.resize(threads_[0]->getOutArgs().size());
+
+  outArgs_.resize(threads_[threads_.size() - 1]->getOutArgs().size());

   REGISTER_TIMER("copyOutArgs");
   for (size_t i = 0; i < outArgs_.size(); ++i) {
     std::vector<Argument> args;
     args.reserve(threads_.size());
     for (auto& thread : threads_) {
-      args.push_back(thread->getOutArgs()[i]);
+      // If the thread input is empty, then the output is empty.
+      auto tmp = thread->getOutArgs();
+      if (tmp.size() > 0) {
+        args.push_back(tmp[i]);
+      }
     }
     outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
   }
@@ -534,7 +541,7 @@ void TrainerThread::prefetch() {
 void TrainerThread::forward() {
   if (!inArgsCopied_) {
     REGISTER_TIMER("copyInArgs");
-    copyInArgs();
+    batchSize_ = copyInArgs();
   } else {
     inArgsCopied_ = false;
   }
@@ -564,7 +571,12 @@ void TrainerThread::forward() {

   {
     REGISTER_TIMER("thread_forward");
-    gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
+    if (batchSize_ > 0) {
+      gradientMachine_->forward(
+          inArgs_, &outArgs_, multiMachine_->getPassType());
+    } else {
+      outArgs_.clear();
+    }
   }
   outArgsReadySem_.post();
 }
@@ -574,7 +586,13 @@ void TrainerThread::backward() {
   if (multiMachine_->isPassGrad()) {
     copyOutputGrad();
   }
-  gradientMachine_->backward(backwardCallback_);
+  if (batchSize_ > 0) {
+    gradientMachine_->backward(backwardCallback_);
+  } else {
+    for (size_t i = parameters_.size(); i > 0; i--) {
+      backwardCallback(parameters_[i - 1].get());
+    }
+  }
   if (multiMachine_->hasNonstaticCpuParamters()) {
     mergeCpuGradients();
   }
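When batchSize_ is 0, forward() above skips the network and clears outArgs_, while backward() still walks parameters_ in reverse and fires backwardCallback for each one, so whatever the callback drives (presumably gradient dispatch and synchronization with the other threads) is not left waiting on this thread. A standalone sketch of that "fire the completion callback even when there is no work" pattern follows; the names are hypothetical, not PaddlePaddle's types.

// Standalone sketch (hypothetical names): a worker with an empty batch still
// invokes the per-parameter callback so that downstream bookkeeping runs.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

void runBackward(const std::vector<std::string>& params,
                 int batchSize,
                 const std::function<void(const std::string&)>& callback) {
  if (batchSize > 0) {
    // Normal path: compute gradients, calling back as each one is ready.
    for (auto it = params.rbegin(); it != params.rend(); ++it) {
      std::cout << "computing gradient for " << *it << "\n";
      callback(*it);
    }
  } else {
    // Empty batch: nothing to compute, but the callbacks still fire,
    // in reverse order, mirroring the hunk above.
    for (size_t i = params.size(); i > 0; i--) {
      callback(params[i - 1]);
    }
  }
}

int main() {
  std::vector<std::string> params = {"w1", "w2", "bias"};
  runBackward(params, /*batchSize=*/0,
              [](const std::string& p) { std::cout << p << " notified\n"; });
}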
@@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) {
   notifyValueDispatch(paramId);
 }

-void TrainerThread::copyInArgs() {
+int TrainerThread::copyInArgs() {
   const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
   int numThreads = multiMachine_->getAllThreads().size();
   int32_t numSequences = fullInArgs[0].getNumSequences();
@@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() {
   }

   if (copySize == 0) {
-    return;
+    return 0;
   }

   for (size_t i = 0; i < fullInArgs.size(); i++) {
@@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() {
         copySize,
         FLAGS_parallel_nn ? false : multiMachine_->useGpu());
   }
+  return copySize;
 }

 void TrainerThread::mergeCpuGradients() {
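copyInArgs() now reports how many sequences it actually copied, and forward() stores that in batchSize_; a return value of 0 is what triggers the empty-batch branches above. The slicing arithmetic itself is outside this diff, but the sketch below shows one assumed even split in which trailing threads end up with copySize == 0 whenever there are fewer sequences than threads; sliceSize() is an illustrative stand-in, not the actual PaddlePaddle logic.

// Standalone sketch (assumed even split, not the exact PaddlePaddle code):
// when numSequences < numThreads, the trailing threads get copySize == 0.
#include <algorithm>
#include <iostream>

int sliceSize(int numSequences, int numThreads, int threadId) {
  int perThread = (numSequences + numThreads - 1) / numThreads;  // ceil
  int start = std::min(threadId * perThread, numSequences);
  int end = std::min(start + perThread, numSequences);
  return end - start;  // 0 for threads beyond the available sequences
}

int main() {
  // 3 sequences split over 4 threads: the last thread gets nothing.
  for (int t = 0; t < 4; ++t) {
    std::cout << "thread " << t << ": copySize = " << sliceSize(3, 4, t) << "\n";
  }
}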