File tree Expand file tree Collapse file tree 4 files changed +39
-18
lines changed
Expand file tree Collapse file tree 4 files changed +39
-18
lines changed Original file line number Diff line number Diff line change @@ -35,12 +35,12 @@ class ModuleList : public CloneableModule<ModuleList> {
3535
3636 std::vector<std::shared_ptr<Tensor>> Forward (const std::vector<std::shared_ptr<Tensor>> &input_tensors) override ;
3737
38- auto begin () { return module_list_. begin (); }
39- auto end () { return module_list_. end (); }
40- auto begin () const { return module_list_. begin (); }
41- auto end () const { return module_list_. end (); }
38+ std::vector<std::shared_ptr<Module>>::iterator begin ();
39+ std::vector<std::shared_ptr<Module>>::iterator end ();
40+ std::vector<std::shared_ptr<Module>>::const_iterator begin () const ;
41+ std::vector<std::shared_ptr<Module>>::const_iterator end () const ;
4242
43- std::shared_ptr<Module> &operator [](std::size_t idx) { return module_list_. at (idx); }
43+ std::shared_ptr<Module> &operator [](std::size_t idx);
4444
4545private:
4646 std::vector<std::shared_ptr<Module>> module_list_;
Original file line number Diff line number Diff line change @@ -22,19 +22,19 @@ class PipelineStage {
2222 std::vector<std::shared_ptr<Tensor>> ForwardOneChunk (const std::vector<std::shared_ptr<Tensor>> &inputs,
2323 int local_chunk_idx = 0 );
2424
25- bool IsFirstStage () const { return stage_index_ == 0 ; }
26- bool IsLastStage () const { return stage_index_ == num_stages_ - 1 ; }
27-
28- int stage_index () const { return stage_index_; }
29- int prev_rank () const { return prev_rank_; }
30- int next_rank () const { return next_rank_; }
31- int num_stages () const { return num_stages_; }
32-
33- const Device *device () const { return device_; }
34- const std::vector<std::vector<int64_t >> &recv_shape () const { return recv_shape_; }
35- std::shared_ptr<Optimizer> optimizer () { return optimizer_; }
36- const auto &chunks () { return chunks_; }
37- auto *mutable_chunks () { return &chunks_; }
25+ bool IsFirstStage () const ;
26+ bool IsLastStage () const ;
27+
28+ int stage_index () const ;
29+ int prev_rank () const ;
30+ int next_rank () const ;
31+ int num_stages () const ;
32+
33+ const Device *device () const ;
34+ const std::vector<std::vector<int64_t >> &recv_shape () const ;
35+ std::shared_ptr<Optimizer> optimizer ();
36+ const std::vector<std::shared_ptr<Module>> &chunks ();
37+ std::vector<std::shared_ptr<Module>> *mutable_chunks ();
3838
3939private:
4040 int stage_index_ = -1 ;
Original file line number Diff line number Diff line change @@ -41,4 +41,12 @@ ModuleList::ModuleList(std::vector<std::shared_ptr<Module>> &&layers)
4141std::vector<std::shared_ptr<Tensor>> ModuleList::Forward (const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
4242 LOG (FATAL) << " Not implemented" ;
4343}
44+
45+ std::vector<std::shared_ptr<Module>>::iterator ModuleList::begin () { return module_list_.begin (); }
46+ std::vector<std::shared_ptr<Module>>::iterator ModuleList::end () { return module_list_.end (); }
47+ std::vector<std::shared_ptr<Module>>::const_iterator ModuleList::begin () const { return module_list_.begin (); }
48+ std::vector<std::shared_ptr<Module>>::const_iterator ModuleList::end () const { return module_list_.end (); }
49+
50+ std::shared_ptr<Module> &ModuleList::operator [](std::size_t idx) { return module_list_.at (idx); }
51+
4452} // namespace infini_train::nn
Original file line number Diff line number Diff line change @@ -28,4 +28,17 @@ std::vector<std::shared_ptr<Tensor>> PipelineStage::ForwardOneChunk(const std::v
2828 return chunks_[local_chunk_idx]->Forward (inputs);
2929}
3030
31+ bool PipelineStage::IsFirstStage () const { return stage_index_ == 0 ; }
32+ bool PipelineStage::IsLastStage () const { return stage_index_ == num_stages_ - 1 ; }
33+
34+ int PipelineStage::stage_index () const { return stage_index_; }
35+ int PipelineStage::prev_rank () const { return prev_rank_; }
36+ int PipelineStage::next_rank () const { return next_rank_; }
37+ int PipelineStage::num_stages () const { return num_stages_; }
38+
39+ const Device *PipelineStage::device () const { return device_; }
40+ const std::vector<std::vector<int64_t >> &PipelineStage::recv_shape () const { return recv_shape_; }
41+ std::shared_ptr<Optimizer> PipelineStage::optimizer () { return optimizer_; }
42+ const std::vector<std::shared_ptr<Module>> &PipelineStage::chunks () { return chunks_; }
43+ std::vector<std::shared_ptr<Module>> *PipelineStage::mutable_chunks () { return &chunks_; }
3144} // namespace infini_train::nn::parallel
You can’t perform that action at this time.
0 commit comments