Skip to content

Commit a2dca69

Browse files
JYMiracle305kilinchange
authored andcommitted
feat: add virtual_pipeline_parallel test scripts
1 parent 758c6e2 commit a2dca69

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

infini_train/src/nn/parallel/pp/pipeline_schedule.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,10 @@ float PipelineSchedule::StepMicroBatches(const std::vector<std::shared_ptr<Tenso
194194

195195
auto schedule = PipelineParallelScheduler::GenerateGPipeSchedule(n, num_stages, vpp_size);
196196

197-
if (stage_idx == 0) {
197+
static bool has_printed = false;
198+
if (!has_printed && stage_idx == 0) {
198199
PrintScheduleTable(schedule, n, num_stages, vpp_size);
200+
has_printed = true;
199201
}
200202

201203
float total_loss = 0.0f;

scripts/test_config.json

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,30 @@
141141
},
142142
{
143143
"id": "7",
144+
"args": {
145+
"dtype": "float32",
146+
"nthread_per_process": 4,
147+
"num_iteration": 10,
148+
"batch_size": 10,
149+
"total_batch_size": 5120,
150+
"pipeline_parallel": 4,
151+
"virtual_pipeline_parallel": 2
152+
}
153+
},
154+
{
155+
"id": "7_bfloat16",
156+
"args": {
157+
"dtype": "bfloat16",
158+
"nthread_per_process": 4,
159+
"num_iteration": 10,
160+
"batch_size": 10,
161+
"total_batch_size": 5120,
162+
"pipeline_parallel": 4,
163+
"virtual_pipeline_parallel": 2
164+
}
165+
},
166+
{
167+
"id": "8",
144168
"args": {
145169
"dtype": "float32",
146170
"nthread_per_process": 8,
@@ -149,11 +173,12 @@
149173
"total_batch_size": 5120,
150174
"tensor_parallel": 2,
151175
"sequence_parallel": true,
152-
"pipeline_parallel": 2
176+
"pipeline_parallel": 2,
177+
"virtual_pipeline_parallel": 2
153178
}
154179
},
155180
{
156-
"id": "7_bfloat16",
181+
"id": "8_bfloat16",
157182
"args": {
158183
"dtype": "bfloat16",
159184
"nthread_per_process": 8,
@@ -162,7 +187,8 @@
162187
"total_batch_size": 5120,
163188
"tensor_parallel": 2,
164189
"sequence_parallel": true,
165-
"pipeline_parallel": 2
190+
"pipeline_parallel": 2,
191+
"virtual_pipeline_parallel": 2
166192
}
167193
}
168194
]

0 commit comments

Comments
 (0)