@@ -200,12 +200,17 @@ class ParallelExecutorPrivate {
     InitNCCLCtxs(scope, bst);
   }
 #endif
+  inline bool IsPersistable(const std::string &name) const {
+    auto iter = is_persistable_.find(name);
+    return iter != is_persistable_.end() && iter->second;
+  }
 
   BuildStrategy build_strategy_;
   std::vector<platform::Place> places_;
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;  // not owned
   std::unique_ptr<details::SSAGraphExecutor> executor_;
+  std::unordered_map<std::string, bool> is_persistable_;
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   platform::NCCLCommunicator *nccl_ctxs_{nullptr};
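
The helper added above treats an unregistered name and a name registered as non-persistable identically: both yield false. A minimal standalone sketch of that lookup pattern follows; the PersistableRegistry class, the variable names, and main() are illustrative, not part of the patch.

#include <iostream>
#include <string>
#include <unordered_map>

// Sketch of the lookup pattern used by IsPersistable(): a missing key and a
// key mapped to `false` are treated the same way.
class PersistableRegistry {
 public:
  void Register(const std::string &name, bool persistable) {
    is_persistable_.emplace(name, persistable);
  }
  bool IsPersistable(const std::string &name) const {
    auto iter = is_persistable_.find(name);
    return iter != is_persistable_.end() && iter->second;
  }

 private:
  std::unordered_map<std::string, bool> is_persistable_;
};

int main() {
  PersistableRegistry registry;
  registry.Register("fc_0.w_0", true);  // a weight: persistable
  registry.Register("image", false);    // a feed input: not persistable
  std::cout << registry.IsPersistable("fc_0.w_0") << "\n";  // 1
  std::cout << registry.IsPersistable("image") << "\n";     // 0
  std::cout << registry.IsPersistable("unknown") << "\n";   // 0 (missing key)
  return 0;
}
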
@@ -473,6 +478,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
       var_infos.back().name_ = node->Var()->Name();
       var_infos.back().type_ = node->Var()->GetType();
       var_infos.back().persistable_ = node->Var()->Persistable();
+      member_->is_persistable_.emplace(node->Var()->Name(),
+                                       node->Var()->Persistable());
     }
   }
 
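
One detail worth noting about the registration above: std::unordered_map::emplace inserts only when the key is absent, so if the graph ever contained two variable nodes with the same name, the first recorded flag would win. A small standalone demonstration of that standard-library behavior (illustrative only, not from the patch):

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, bool> is_persistable;
  // emplace() inserts only if the key is absent; a second emplace with the
  // same key is a no-op and reports failure through the returned bool.
  auto first = is_persistable.emplace("fc_0.w_0", true);
  auto second = is_persistable.emplace("fc_0.w_0", false);
  std::cout << first.second << "\n";                // 1: inserted
  std::cout << second.second << "\n";               // 0: key already present
  std::cout << is_persistable["fc_0.w_0"] << "\n";  // 1: original value kept
  return 0;
}
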
@@ -642,23 +649,58 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
 
 void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
     const std::unordered_map<std::string, LoDTensor> &tensors) {
-  for (auto pair : tensors) {
+  size_t num_places = member_->places_.size();
+  for (auto &pair : tensors) {
+    bool is_persistable = member_->IsPersistable(pair.first);
+    VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
+            << " data (" << pair.first << "), dim:" << pair.second.dims()
+            << ", place: " << pair.second.place();
     auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
-    if (member_->places_.size() != lod_tensors.size()) {
-      bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
+    bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
+    if (!is_persistable && num_places != lod_tensors.size()) {
       auto error_info = string::Sprintf(
-          "The number(%d) of samples of "
-          "current batch is less than the count(%d) of "
-          "devices(%s), currently, it is not allowed.",
-          lod_tensors.size(), member_->places_.size(),
+          "The number(%d) of samples[%s] of current batch is less than the "
+          "count(%d) of devices(%s), currently, it is not allowed.",
+          lod_tensors.size(), pair.first, num_places,
           (is_cpu_place ? "CPU" : "GPU"));
       if (is_cpu_place) {
         error_info +=
             "You should set the environment variable CPU_NUM in the system "
             "to determine the number of devices you need.";
       }
       PADDLE_THROW(error_info);
+    } else if (is_persistable) {
+      if (lod_tensors.size() == 1) {
+        lod_tensors.reserve(num_places);
+        auto &tensor = lod_tensors.front();
+        PADDLE_ENFORCE_EQ(tensor.dims(), pair.second.dims(),
+                          "The dim doesn't match.");
+        PADDLE_ENFORCE_EQ(tensor.place(), member_->places_.at(0),
+                          "The place doesn't match.");
+        for (size_t i = 1; i < num_places; ++i) {
+          lod_tensors.emplace_back();
+          auto &tmp = lod_tensors.back();
+          framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
+        }
+      }
+      if (lod_tensors.size() != num_places) {
+        auto error_info = string::Sprintf(
+            "The number(%d) of samples[%s] of the current batch does not match "
+            "the count(%d) of devices(%s). Because that %s is a persistable "
+            "variable, you can feed just one sample, in that case, the input "
+            "sample will be copied in %d copies and be sent to different "
+            "places separately. If you need that different place has different "
+            "value, you should feed %d samples.",
+            lod_tensors.size(), pair.first, num_places,
+            (is_cpu_place ? "CPU" : "GPU"), pair.first, num_places, num_places);
+        PADDLE_THROW(error_info);
+      }
     }
+    PADDLE_ENFORCE_EQ(
+        lod_tensors.size(), num_places,
+        "The number(%d) of samples of the current batch does not match the "
+        "count(%d) of devices.",
+        lod_tensors.size(), num_places);
     for (size_t j = 0; j < member_->places_.size(); ++j) {
       // TODO(panxy0718): Do I need to delete this var?
       auto t =
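
Taken together, the new branches mean a persistable variable may be fed as a single tensor that is then copied onto every place, while a non-persistable feed must split into exactly one piece per place. Below is a minimal standalone sketch of that dispatch in plain C++; the Tensor struct and SplitForPlaces() are illustrative stand-ins for LoDTensor, SplitLoDTensor(), and framework::TensorCopy(), not PaddlePaddle APIs.

#include <cstddef>
#include <stdexcept>
#include <vector>

// Toy stand-ins for LoDTensor and platform::Place, just to model the flow.
struct Tensor {
  std::vector<float> data;
  int place;  // index of the device this tensor lives on
};

// Mirrors the patched logic: a persistable feed of size 1 is replicated onto
// every place; any other feed must already provide one tensor per place.
std::vector<Tensor> SplitForPlaces(const Tensor &src,
                                   std::vector<Tensor> pieces,
                                   bool is_persistable, size_t num_places) {
  if (!is_persistable && pieces.size() != num_places) {
    throw std::runtime_error("batch has fewer samples than devices");
  }
  if (is_persistable && pieces.size() == 1) {
    pieces.reserve(num_places);
    for (size_t i = 1; i < num_places; ++i) {
      Tensor copy = src;  // stands in for framework::TensorCopy
      copy.place = static_cast<int>(i);
      pieces.push_back(copy);
    }
  }
  if (pieces.size() != num_places) {
    throw std::runtime_error("sample count does not match device count");
  }
  return pieces;
}

int main() {
  Tensor weight{{0.1f, 0.2f}, /*place=*/0};
  // One persistable sample fed for 4 devices: replicated into 4 copies.
  auto copies = SplitForPlaces(weight, {weight}, /*is_persistable=*/true, 4);
  return copies.size() == 4 ? 0 : 1;
}
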