Skip to content

Commit 83d11cc

Browse files
JYMiracle305 authored and kilinchange committed
fix: resolve review comments
1 parent a2dca69 commit 83d11cc

File tree

4 files changed

+39
-18
lines changed

4 files changed

+39
-18
lines changed

infini_train/include/nn/modules/container.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ class ModuleList : public CloneableModule<ModuleList> {
3535

3636
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
3737

38-
auto begin() { return module_list_.begin(); }
39-
auto end() { return module_list_.end(); }
40-
auto begin() const { return module_list_.begin(); }
41-
auto end() const { return module_list_.end(); }
38+
std::vector<std::shared_ptr<Module>>::iterator begin();
39+
std::vector<std::shared_ptr<Module>>::iterator end();
40+
std::vector<std::shared_ptr<Module>>::const_iterator begin() const;
41+
std::vector<std::shared_ptr<Module>>::const_iterator end() const;
4242

43-
std::shared_ptr<Module> &operator[](std::size_t idx) { return module_list_.at(idx); }
43+
std::shared_ptr<Module> &operator[](std::size_t idx);
4444

4545
private:
4646
std::vector<std::shared_ptr<Module>> module_list_;

infini_train/include/nn/parallel/pp/pipeline_stage.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ class PipelineStage {
2222
std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
2323
int local_chunk_idx = 0);
2424

25-
bool IsFirstStage() const { return stage_index_ == 0; }
26-
bool IsLastStage() const { return stage_index_ == num_stages_ - 1; }
27-
28-
int stage_index() const { return stage_index_; }
29-
int prev_rank() const { return prev_rank_; }
30-
int next_rank() const { return next_rank_; }
31-
int num_stages() const { return num_stages_; }
32-
33-
const Device *device() const { return device_; }
34-
const std::vector<std::vector<int64_t>> &recv_shape() const { return recv_shape_; }
35-
std::shared_ptr<Optimizer> optimizer() { return optimizer_; }
36-
const auto &chunks() { return chunks_; }
37-
auto *mutable_chunks() { return &chunks_; }
25+
bool IsFirstStage() const;
26+
bool IsLastStage() const;
27+
28+
int stage_index() const;
29+
int prev_rank() const;
30+
int next_rank() const;
31+
int num_stages() const;
32+
33+
const Device *device() const;
34+
const std::vector<std::vector<int64_t>> &recv_shape() const;
35+
std::shared_ptr<Optimizer> optimizer();
36+
const std::vector<std::shared_ptr<Module>> &chunks();
37+
std::vector<std::shared_ptr<Module>> *mutable_chunks();
3838

3939
private:
4040
int stage_index_ = -1;

infini_train/src/nn/modules/container.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,12 @@ ModuleList::ModuleList(std::vector<std::shared_ptr<Module>> &&layers)
4141
std::vector<std::shared_ptr<Tensor>> ModuleList::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
4242
LOG(FATAL) << "Not implemented";
4343
}
44+
45+
std::vector<std::shared_ptr<Module>>::iterator ModuleList::begin() { return module_list_.begin(); }
46+
std::vector<std::shared_ptr<Module>>::iterator ModuleList::end() { return module_list_.end(); }
47+
std::vector<std::shared_ptr<Module>>::const_iterator ModuleList::begin() const { return module_list_.begin(); }
48+
std::vector<std::shared_ptr<Module>>::const_iterator ModuleList::end() const { return module_list_.end(); }
49+
50+
std::shared_ptr<Module> &ModuleList::operator[](std::size_t idx) { return module_list_.at(idx); }
51+
4452
} // namespace infini_train::nn

infini_train/src/nn/parallel/pp/pipeline_stage.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,17 @@ std::vector<std::shared_ptr<Tensor>> PipelineStage::ForwardOneChunk(const std::v
2828
return chunks_[local_chunk_idx]->Forward(inputs);
2929
}
3030

31+
bool PipelineStage::IsFirstStage() const { return stage_index_ == 0; }
32+
bool PipelineStage::IsLastStage() const { return stage_index_ == num_stages_ - 1; }
33+
34+
int PipelineStage::stage_index() const { return stage_index_; }
35+
int PipelineStage::prev_rank() const { return prev_rank_; }
36+
int PipelineStage::next_rank() const { return next_rank_; }
37+
int PipelineStage::num_stages() const { return num_stages_; }
38+
39+
const Device *PipelineStage::device() const { return device_; }
40+
const std::vector<std::vector<int64_t>> &PipelineStage::recv_shape() const { return recv_shape_; }
41+
std::shared_ptr<Optimizer> PipelineStage::optimizer() { return optimizer_; }
42+
const std::vector<std::shared_ptr<Module>> &PipelineStage::chunks() { return chunks_; }
43+
std::vector<std::shared_ptr<Module>> *PipelineStage::mutable_chunks() { return &chunks_; }
3144
} // namespace infini_train::nn::parallel

0 commit comments

Comments (0)