@@ -148,6 +148,8 @@ void Train(const nn::parallel::Rank &rank) {
             pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                 GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
             pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+
+            nn::parallel::pp_rank = pp_rank;
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
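The added assignment publishes the computed pipeline-parallel rank into module-level state so code outside Train can query it. A minimal sketch of how that might be consumed, assuming nn::parallel::pp_rank is a thread-local int set once per worker thread (the declaration form and the IsFirstPipelineStage helper below are hypothetical, not the project's actual code):

    // Hedged sketch, not the project's actual declaration:
    // assumes nn::parallel::pp_rank is thread-local module state.
    namespace nn::parallel {
    extern thread_local int pp_rank;  // set in Train() above
    }

    // Hypothetical convenience helper: callers elsewhere in the codebase
    // could branch on the published rank like this.
    inline bool IsFirstPipelineStage() { return nn::parallel::pp_rank == 0; }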
@@ -243,8 +245,9 @@ void Train(const nn::parallel::Rank &rank) {
     CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
         << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
         << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
-    auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
-                                                     FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
+
+    auto shapes = std::vector<std::vector<int64_t>>{
+        {(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches, FLAGS_sequence_length, model_config.n_embd}};
 
     model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
                                                              pp_rank, optimizer_factory);
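The rewritten shapes initializer describes the activation tensor exchanged between pipeline stages per microbatch: the global batch (FLAGS_batch_size * pp_world_size) split across FLAGS_num_microbatches, by sequence length, by hidden width, now read from the typed model_config.n_embd field instead of the GetConfig() map. A self-contained sketch of the arithmetic, with example flag values that are assumptions chosen for illustration:

    #include <cstdint>
    #include <vector>

    int main() {
        // Assumed values for illustration only.
        const int64_t batch_size = 8, pp_world_size = 2, num_microbatches = 4;
        const int64_t sequence_length = 1024, n_embd = 768;
        // Per-microbatch activation shape handed to PipelineParallel:
        // {(8 * 2) / 4, 1024, 768} == {4, 1024, 768}.
        std::vector<std::vector<int64_t>> shapes{
            {(batch_size * pp_world_size) / num_microbatches, sequence_length, n_embd}};
        return 0;
    }

The CHECK_EQ above guards exactly this division: if the global batch is not a multiple of the microbatch count, the leading dimension would truncate.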
@@ -298,9 +301,9 @@ void Train(const nn::parallel::Rank &rank) {
         x = std::make_shared<Tensor>(x->To(device));
         y = std::make_shared<Tensor>(y->To(device));
 
+        // FIXME(jym): gradient accumulation is not yet supported here
         if (pp_world_size > 1) {
             lossf = model->TrainStep({x}, {y}, loss_fn);
-
             auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
             static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
             auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
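In the pipeline-parallel path, the scalar loss returned by TrainStep is staged into a rank-0 tensor on the host and then copied to the device, presumably so it can take part in a cross-rank collective before logging; the hunk ends before that step, so the continuation below is a hedged sketch and the Broadcast call is an assumption, not the project's confirmed API:

    // Shape {} constructs a rank-0 (scalar) tensor; the host-side write and
    // the To(device) copy stage the value for cross-rank use.
    auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
    static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;  // fill on host
    auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
    // Hypothetical follow-up so every stage logs the same value; the real
    // group/API may differ:
    // pp_pg->Broadcast(loss_device_ptr, /*root=*/pp_world_size - 1);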