diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 4a34c464..c43a04d2 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -68,6 +68,9 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +// precision check +DEFINE_int32(precision_check, 0, "precision check level: 0=off, 1=module, 2=function"); +DEFINE_bool(precision_check_all_ranks, false, "enable precision check for all ranks (default: rank 0 only)"); using namespace infini_train; @@ -297,9 +300,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; + auto logits = (*model)({x, y})[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; + auto loss = (*loss_fn)({logits, y})[0]; // FIXME(jym): verify gradient accumulation precision loss = loss / grad_accum_steps; @@ -364,7 +367,8 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel, FLAGS_precision_check, + FLAGS_precision_check_all_ranks); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 69e4278e..18e07dca 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -47,7 +47,7 @@ NewGELU::Forward(const std::vector> &x) { } CausalSelfAttention::CausalSelfAttention(const GPT2Config &config) - : config_(config), n_head_(config.n_head), n_embd_(config.n_embd) { + : CloneableModule(kType), config_(config), n_head_(config.n_head), n_embd_(config.n_embd) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); CHECK_EQ(config.n_embd % config.n_head, 0); CHECK_EQ(n_head_ % tp_world_size, 0) << "n_head must be divisible by TP world size"; @@ -89,7 +89,7 @@ CausalSelfAttention::Forward(const std::vector ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C) // -> Split -> (3, B, T, local_C) - auto qkv = modules_[kCAttnLayerName]->Forward(x)[0]->Split(local_C, 2); + auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2); // (B, T, local_C) auto q = qkv[0]; @@ -120,12 +120,12 @@ CausalSelfAttention::Forward(const std::vector RowParallelLinear(n_embd, n_embd) -> (B, T, C) - y = modules_[kCProjLayerName]->Forward({y})[0]; + y = (*modules_[kCProjLayerName])({y})[0]; // (B, T, C) == (bs, seq_len, n_embd) return {y}; } -MLP::MLP(const GPT2Config &config) { +MLP::MLP(const GPT2Config &config) : CloneableModule(kType) { // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/4 * config.n_embd, @@ -150,16 +150,16 @@ MLP::MLP(const GPT2Config &config) { std::vector> MLP::Forward(const std::vector> &x) { // (B, T, C) -> ColumnParallelLinear(C, 4 * C) -> (B, T, 4 * C_local) - auto x1 = modules_[kCFcLayerName]->Forward(x); + auto x1 = (*modules_[kCFcLayerName])(x); // (B, T, 4 * C_local) -> GELU -> (B, T, 4 * C_local) - auto x2 = modules_[kGeluLayerName]->Forward(x1); + auto x2 = (*modules_[kGeluLayerName])(x1); // (B, T, 4 * C_local) -> RowParallelLinear(4 * C, C) -> (B, T, C) - auto x3 = modules_[kCProjLayerName]->Forward(x2); 
+ auto x3 = (*modules_[kCProjLayerName])(x2); // (B, T, C) return x3; } -Block::Block(const GPT2Config &config) { +Block::Block(const GPT2Config &config) : CloneableModule(kType) { modules_[kLn1LayerName] = std::make_shared(std::vector{config.n_embd}); modules_[kAttnLayerName] = std::make_shared(config); modules_[kLn2LayerName] = std::make_shared(std::vector{config.n_embd}); @@ -170,15 +170,15 @@ std::vector> Block::Forward(const std::vector> &x) { // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) - auto x1 = x[0] + modules_[kAttnLayerName]->Forward(modules_[kLn1LayerName]->Forward(x))[0]; + auto x1 = x[0] + (*modules_[kAttnLayerName])((*modules_[kLn1LayerName])(x))[0]; // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) - auto x2 = x1 + modules_[kMlpLayerName]->Forward(modules_[kLn2LayerName]->Forward({x1}))[0]; + auto x2 = x1 + (*modules_[kMlpLayerName])((*modules_[kLn2LayerName])({x1}))[0]; // (bs, seq_len, n_embd) return {x2}; } -GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : config_(config) { +GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : CloneableModule(kType), config_(config) { modules_[kWTELayerName] = std::make_shared( config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); @@ -207,15 +207,15 @@ GPT2FirstStage::Forward(const std::vector> auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = modules_[kWTELayerName]->Forward({x1})[0]; + auto tok_emb = (*modules_[kWTELayerName])({x1})[0]; // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = modules_[kWPELayerName]->Forward({pos})[0]; + auto pos_emb = (*modules_[kWPELayerName])({pos})[0]; // (B, T, C) return {tok_emb + pos_emb}; } -GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : config_(config) { +GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : CloneableModule(kType), config_(config) { std::vector> h; for (int64_t i = start_layer; i < end_layer; ++i) { auto layer = std::make_shared(config); @@ -228,11 +228,11 @@ std::vector> GPT2Chunk::Forward(const std::vector> &x) { auto x1 = x[0]; // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = h->Forward({x1})[0]; } + for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } return {x1}; } -GPT2LastStage::GPT2LastStage(const GPT2Config &config) : config_(config) { +GPT2LastStage::GPT2LastStage(const GPT2Config &config) : CloneableModule(kType), config_(config) { modules_[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); // don't init this one, we will tie weights modules_[kLMHeadLayerName] = std::make_shared( @@ -248,15 +248,15 @@ GPT2LastStage::GPT2LastStage(const GPT2Config &config) : config_(config) { std::vector> GPT2LastStage::Forward(const std::vector> &x) { // (B, T, C) -> Layernorm -> (B, T, C) - auto x1 = modules_[kLnFLayerName]->Forward(x); + auto x1 = (*modules_[kLnFLayerName])(x); // TODO(dcj): add inference-time mini-optimization // (B, T, C) -> Linear(C, V) -> (B, T, V) - return modules_[kLMHeadLayerName]->Forward(x1); + return (*modules_[kLMHeadLayerName])(x1); } GPT2::GPT2(const 
GPT2Config &config) - : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + : CloneableModule(kType), config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, nn::parallel::global::GetVirtualPipelineParallelSize())) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -316,11 +316,11 @@ GPT2::GPT2(const GPT2Config &config) std::vector> GPT2::Forward(const std::vector> &x) { - auto x1 = modules_[kPPFirstStageName]->Forward(x); + auto x1 = (*modules_[kPPFirstStageName])(x); for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); + x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1); } - return modules_[kPPLastStageName]->Forward(x1); + return (*modules_[kPPLastStageName])(x1); } std::shared_ptr GPT2::FromPretrained(ModelType model_type) { diff --git a/example/gpt2/net.h b/example/gpt2/net.h index f52e4d4e..4faf5451 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -23,12 +23,16 @@ struct GPT2Config { class NewGELU : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "NewGELU"; + NewGELU() : CloneableModule(kType) {} + std::vector> Forward(const std::vector> &x) override; }; class CausalSelfAttention : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "CausalSelfAttention"; static constexpr char kCAttnLayerName[] = "c_attn"; static constexpr char kCProjLayerName[] = "c_proj"; @@ -49,6 +53,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "MLP"; static constexpr char kCFcLayerName[] = "c_fc"; static constexpr char kGeluLayerName[] = "gelu"; static constexpr char kCProjLayerName[] = "c_proj"; @@ -61,6 +66,7 @@ class MLP : public infini_train::nn::CloneableModule { class Block : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "Block"; static constexpr char kLn1LayerName[] = "ln_1"; static constexpr char kAttnLayerName[] = "attn"; static constexpr char kLn2LayerName[] = "ln_2"; @@ -74,6 +80,7 @@ class Block : public infini_train::nn::CloneableModule { class GPT2FirstStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "GPT2FirstStage"; static constexpr char kWTELayerName[] = "wte"; static constexpr char kWPELayerName[] = "wpe"; @@ -88,6 +95,7 @@ class GPT2FirstStage : public infini_train::nn::CloneableModule class GPT2Chunk : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "GPT2Chunk"; static constexpr char kHLayerName[] = "h"; GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer); @@ -101,6 +109,7 @@ class GPT2Chunk : public infini_train::nn::CloneableModule { class GPT2LastStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "GPT2LastStage"; static constexpr char kLnFLayerName[] = "ln_f"; static constexpr char kLMHeadLayerName[] = "lm_head"; @@ -115,6 +124,7 @@ class GPT2LastStage : public infini_train::nn::CloneableModule { class GPT2 : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "GPT2"; static constexpr char kTransformerLayerName[] = "transformer"; enum class ModelType : int8_t { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index fdea2162..c6db113e 
100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -66,6 +66,9 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +// precision check +DEFINE_int32(precision_check, 0, "precision check level: 0=off, 1=module, 2=function"); +DEFINE_bool(precision_check_all_ranks, false, "enable precision check for all ranks (default: rank 0 only)"); using namespace infini_train; @@ -273,9 +276,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; + auto logits = (*model)({x, y})[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; + auto loss = (*loss_fn)({logits, y})[0]; // FIXME(jym): verify gradient accumulation precision loss = loss / grad_accum_steps; @@ -340,7 +343,8 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel, FLAGS_precision_check, + FLAGS_precision_check_all_ranks); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a70a811a..12bcf0ed 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -127,7 +127,7 @@ std::vector> SwiGLU::Forward(const std::vector(std::vector{dim}, DataType::kFLOAT32, device)->RequiresGrad(); nn::init::Ones(parameters_[kParamWeightName]); @@ -140,7 +140,7 @@ std::vector> RMSNorm::Forward(const std::vector> CausalSelfAttention::Forward(const std::vec CHECK(freqs_cis != nullptr) << "freqs_cis is null."; // (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D) - auto qkv = modules_[kCAttnLayerName]->Forward({x[0]})[0]; + auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0]; // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear const auto T = qkv->Dims()[1]; // NOTE(zbl): torch script uses torch.split({...}, dim) to split tensors into sub-tensors in different sizes @@ -240,12 +240,12 @@ std::vector> CausalSelfAttention::Forward(const std::vec y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); // output projection // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C) - y = modules_[kCProjLayerName]->Forward({y})[0]; + y = (*modules_[kCProjLayerName])({y})[0]; // (B, H, C) == (bs, seq_len, n_embd) return {y}; } -MLP::MLP(const LLaMA3Config &config) { +MLP::MLP(const LLaMA3Config &config) : CloneableModule(kType) { hidden_dim_ = 4 * config.n_embd; hidden_dim_ = int(2 * hidden_dim_ / 3); // use custom dim factor multiplier @@ -286,20 +286,20 @@ MLP::MLP(const LLaMA3Config &config) { std::vector> MLP::Forward(const std::vector> &x) { // (bs, seq_len, n_embd) -> Linear(n_embd, hidden_dim) -> (bs, seq_len, hidden_dim) - auto x1 = modules_[kCFcLayerName]->Forward(x)[0]; + auto x1 = (*modules_[kCFcLayerName])(x)[0]; // (bs, seq_len, n_embd) -> Linear(n_embd, hidden_dim) -> (bs, seq_len, hidden_dim) - auto x2 = modules_[kCFc2LayerName]->Forward(x)[0]; + auto x2 = (*modules_[kCFc2LayerName])(x)[0]; // (bs, seq_len, hidden_dim) -> SwiGLU -> (bs, seq_len, hidden_dim) - x2 = 
modules_[kSiluLayerName]->Forward({x2})[0]; + x2 = (*modules_[kSiluLayerName])({x2})[0]; // (bs, seq_len, hidden_dim) auto x3 = x1 * x2; // (bs, seq_len, hidden_dim) -> Linear(hidden_dim, n_embd) -> (bs, seq_len, n_embd) - auto x4 = modules_[kCProjLayerName]->Forward({x3}); + auto x4 = (*modules_[kCProjLayerName])({x3}); // (bs, seq_len, n_embd) return x4; } -Block::Block(const LLaMA3Config &config) { +Block::Block(const LLaMA3Config &config) : CloneableModule(kType) { modules_[kLn1LayerName] = std::make_shared(config.n_embd, config.norm_eps); modules_[kAttnLayerName] = std::make_shared(config); modules_[kLn2LayerName] = std::make_shared(config.n_embd, config.norm_eps); @@ -314,27 +314,27 @@ std::vector> Block::Forward(const std::vector RMSNorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) auto x1 = x[0] - + modules_[kAttnLayerName]->Forward(std::vector>{ - modules_[kLn1LayerName]->Forward({x[0]})[0], freqs_cis, start_pos, mask})[0]; + + (*modules_[kAttnLayerName])(std::vector>{ + (*modules_[kLn1LayerName])({x[0]})[0], freqs_cis, start_pos, mask})[0]; // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) auto x2 = x1 - + modules_[kMlpLayerName]->Forward( - std::vector>(modules_[kLn2LayerName]->Forward({x1})))[0]; + + (*modules_[kMlpLayerName])( + std::vector>((*modules_[kLn2LayerName])({x1})))[0]; // (bs, seq_len, n_embd) return {x2}; } -LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : config_(config) { +LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : CloneableModule(kType), config_(config) { modules_[LLaMA3FirstStage::kWTELayerName] = std::make_shared( config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); } std::vector> LLaMA3FirstStage::Forward(const std::vector> &x) { - return modules_[LLaMA3FirstStage::kWTELayerName]->Forward(x); + return (*modules_[LLaMA3FirstStage::kWTELayerName])(x); } -LLaMA3Chunk::LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer) : config_(config) { +LLaMA3Chunk::LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer) : CloneableModule(kType), config_(config) { std::vector> h; for (int64_t i = start_layer; i < end_layer; ++i) { auto layer = std::make_shared(config); @@ -368,12 +368,12 @@ std::vector> LLaMA3Chunk::Forward(const std::vector transformer -> (bs, seq_len, n_embd) for (auto &h : *std::dynamic_pointer_cast(modules_[LLaMA3Chunk::kHLayerName])) { - x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; + x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; } return {x1}; } -LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : config_(config) { +LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : CloneableModule(kType), config_(config) { modules_[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); // NOTE(zbl): weight-tying is possible but torch script did not do so modules_[kLMHeadLayerName] = std::make_shared( @@ -388,15 +388,15 @@ LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : config_(config) { std::vector> LLaMA3LastStage::Forward(const std::vector> &x) { // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x1 = modules_[kLnFLayerName]->Forward(x); + auto x1 = (*modules_[kLnFLayerName])(x); // TODO(zbl): add inference-time mini-optimization // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - return 
modules_[kLMHeadLayerName]->Forward(x1); + return (*modules_[kLMHeadLayerName])(x1); } LLaMA3::LLaMA3(const LLaMA3Config &config) - : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + : CloneableModule(kType), config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, nn::parallel::global::GetVirtualPipelineParallelSize())) { std::unordered_map> transformer; @@ -439,11 +439,11 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) } std::vector> LLaMA3::Forward(const std::vector> &x) { - auto x1 = modules_[kPPFirstStageName]->Forward({x[0]}); + auto x1 = (*modules_[kPPFirstStageName])({x[0]}); for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { - x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); + x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1); } - return modules_[kPPLastStageName]->Forward(x1); + return (*modules_[kPPLastStageName])(x1); } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { diff --git a/example/llama3/net.h b/example/llama3/net.h index 9bd7f9da..034aa9e8 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -42,6 +42,9 @@ struct LLaMA3Config { class SwiGLU : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "SwiGLU"; + SwiGLU() : CloneableModule(kType) {} + std::vector> Forward(const std::vector> &x) override; }; @@ -49,6 +52,7 @@ class SwiGLU : public infini_train::nn::CloneableModule { // TODO(zbl): implement fused kernel class RMSNorm : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "RMSNorm"; static constexpr char kParamWeightName[] = "weight"; explicit RMSNorm(int64_t dim, float eps = 1e-6f, @@ -63,6 +67,7 @@ class RMSNorm : public infini_train::nn::CloneableModule { class CausalSelfAttention : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "CausalSelfAttention"; static constexpr char kCAttnLayerName[] = "c_attn"; static constexpr char kCProjLayerName[] = "c_proj"; @@ -82,6 +87,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "MLP"; static constexpr char kCFcLayerName[] = "c_fc"; static constexpr char kCFc2LayerName[] = "c_fc2"; static constexpr char kSiluLayerName[] = "silu"; @@ -98,6 +104,7 @@ class MLP : public infini_train::nn::CloneableModule { class Block : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "Block"; static constexpr char kLn1LayerName[] = "ln_1"; static constexpr char kAttnLayerName[] = "attn"; static constexpr char kLn2LayerName[] = "ln_2"; @@ -111,6 +118,7 @@ class Block : public infini_train::nn::CloneableModule { class LLaMA3FirstStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "LLaMA3FirstStage"; static constexpr char kWTELayerName[] = "wte"; explicit LLaMA3FirstStage(const LLaMA3Config &config); @@ -124,6 +132,7 @@ class LLaMA3FirstStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "LLaMA3Chunk"; static constexpr char kHLayerName[] = "h"; static constexpr char kFreqsCisName[] = "freqs_cis"; @@ -138,6 +147,7 @@ class LLaMA3Chunk : public infini_train::nn::CloneableModule { class LLaMA3LastStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "LLaMA3LastStage"; static constexpr 
char kLnFLayerName[] = "ln_f"; static constexpr char kLMHeadLayerName[] = "lm_head"; @@ -152,6 +162,7 @@ class LLaMA3LastStage : public infini_train::nn::CloneableModule { public: + static constexpr char kType[] = "LLaMA3"; static constexpr char kTransformerLayerName[] = "transformer"; enum class ModelType : int8_t { diff --git a/example/mnist/net.cc b/example/mnist/net.cc index 4ac83613..501fee7e 100644 --- a/example/mnist/net.cc +++ b/example/mnist/net.cc @@ -25,7 +25,7 @@ MNIST::MNIST() { std::vector> MNIST::Forward(const std::vector> &x) { CHECK_EQ(x.size(), 1); - auto x1 = modules_["sequential"]->Forward(x); - auto x2 = modules_["linear2"]->Forward(x1); + auto x1 = (*modules_["sequential"])(x); + auto x2 = (*modules_["linear2"])(x1); return x2; } diff --git a/infini_train/include/autograd/accumulate.h b/infini_train/include/autograd/accumulate.h index f3519cb1..a8e41e67 100644 --- a/infini_train/include/autograd/accumulate.h +++ b/infini_train/include/autograd/accumulate.h @@ -18,6 +18,8 @@ class AccumulateGrad final : public Function { std::vector> Backward(const std::vector> &) override; + std::shared_ptr tensor() const { return tensor_; } + private: std::shared_ptr tensor_ = nullptr; float learning_rate_ = 1.0f; diff --git a/infini_train/include/autograd/function.h b/infini_train/include/autograd/function.h index bbc091d4..defbf907 100644 --- a/infini_train/include/autograd/function.h +++ b/infini_train/include/autograd/function.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -9,8 +10,17 @@ class Tensor; } namespace infini_train::autograd { +class HookHandle; + class Function : public std::enable_shared_from_this { public: + using FunctionForwardPreHook = std::function>&)>; + using FunctionForwardPostHook = std::function>&, + const std::vector>&)>; + using FunctionBackwardPreHook = std::function>&)>; + using FunctionBackwardPostHook = std::function>&, + const std::vector>&)>; + static constexpr char kUndefinedType[] = "Undefined"; Function() : type_(kUndefinedType) {} @@ -28,6 +38,13 @@ class Function : public std::enable_shared_from_this { void IncreaseDependenciesNumber(); + std::shared_ptr RegisterForwardPreHook(FunctionForwardPreHook hook); + std::shared_ptr RegisterForwardPostHook(FunctionForwardPostHook hook); + std::shared_ptr RegisterBackwardPreHook(FunctionBackwardPreHook hook); + std::shared_ptr RegisterBackwardPostHook(FunctionBackwardPostHook hook); + + const std::string& type() const { return type_; } + protected: std::vector> saved_tensors_; @@ -38,5 +55,10 @@ class Function : public std::enable_shared_from_this { int grad_outputs_reached_ = 0; std::vector> grad_outputs_; const std::string type_ = kUndefinedType; + std::vector forward_pre_hooks_; + std::vector forward_post_hooks_; + std::vector backward_pre_hooks_; + std::vector backward_post_hooks_; + bool precision_check_registered_ = false; }; } // namespace infini_train::autograd diff --git a/infini_train/include/autograd/function_hook.h b/infini_train/include/autograd/function_hook.h index 7d750e03..7d57926f 100644 --- a/infini_train/include/autograd/function_hook.h +++ b/infini_train/include/autograd/function_hook.h @@ -1,6 +1,8 @@ #pragma once +#include #include +#include #include "infini_train/include/nn/parallel/reduce_op_type.h" @@ -13,6 +15,14 @@ class ProcessGroup; } // namespace infini_train namespace infini_train::autograd { +class Function; + +class HookHandle { +public: + virtual ~HookHandle() = default; + virtual void Remove() = 0; +}; + class PostAccumulateGradHook { public: 
virtual void operator()(const std::shared_ptr &tensor) = 0; @@ -30,4 +40,22 @@ class AllReducePostAccumulateHook : public PostAccumulateGradHook { infini_train::nn::parallel::function::ReduceOpType reduce_op_; const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; }; + +template +class FunctionHookHandleImpl : public HookHandle { +public: + FunctionHookHandleImpl(std::vector* hooks, size_t id) : hooks_(hooks), id_(id) {} + + void Remove() override { + if (!removed_ && hooks_ && id_ < hooks_->size()) { + (*hooks_)[id_] = nullptr; + removed_ = true; + } + } + +private: + std::vector* hooks_; + size_t id_; + bool removed_ = false; +}; } // namespace infini_train::autograd diff --git a/infini_train/include/autograd/tensor_hook.h b/infini_train/include/autograd/tensor_hook.h new file mode 100644 index 00000000..f7fcbe37 --- /dev/null +++ b/infini_train/include/autograd/tensor_hook.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace infini_train { +class Tensor; + +namespace autograd { + +// Hook handle for removing hooks +class HookHandle { +public: + virtual ~HookHandle() = default; + virtual void Remove() = 0; +}; + +// Tensor backward hook: modifies gradient during backward pass +// Returns modified gradient or nullptr to keep original +using TensorBackwardHook = std::function(const std::shared_ptr&)>; + +class TensorBackwardHookHandle : public HookHandle { +public: + TensorBackwardHookHandle(std::vector* hooks, size_t id) + : hooks_(hooks), id_(id) {} + + void Remove() override; + +private: + std::vector* hooks_; + size_t id_; + bool removed_ = false; +}; + +} // namespace autograd +} // namespace infini_train diff --git a/infini_train/include/nn/module_hook.h b/infini_train/include/nn/module_hook.h new file mode 100644 index 00000000..ea3b9219 --- /dev/null +++ b/infini_train/include/nn/module_hook.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +namespace infini_train { +class Tensor; + +namespace nn { +class Module; + +// Forward pre-hook: called before forward pass +// Args: (module, input_tensors) +using ForwardPreHook = std::function>&)>; + +// Forward post-hook: called after forward pass +// Args: (module, input_tensors, output_tensors) +using ForwardPostHook = std::function>&, + const std::vector>&)>; + +// Backward pre-hook: called before backward pass +// Args: (module, grad_output) +using BackwardPreHook = std::function>&)>; + +// Backward post-hook: called after backward pass +// Args: (module, grad_input, grad_output) +using BackwardPostHook = std::function>&, + const std::vector>&)>; + +class ModuleHookHandle { +public: + virtual ~ModuleHookHandle() = default; + virtual void Remove() = 0; +}; + +template +class ModuleHookHandleImpl : public ModuleHookHandle { +public: + ModuleHookHandleImpl(std::vector* hooks, size_t id) : hooks_(hooks), id_(id) {} + + void Remove() override { + if (!removed_ && hooks_ && id_ < hooks_->size()) { + (*hooks_)[id_] = nullptr; + removed_ = true; + } + } + +private: + std::vector* hooks_; + size_t id_; + bool removed_ = false; +}; + +} // namespace nn +} // namespace infini_train diff --git a/infini_train/include/nn/modules/activations.h b/infini_train/include/nn/modules/activations.h index b7435cd9..e47be5a3 100644 --- a/infini_train/include/nn/modules/activations.h +++ b/infini_train/include/nn/modules/activations.h @@ -12,7 +12,8 @@ class Tensor; namespace infini_train::nn { class Sigmoid : public CloneableModule { public: - Sigmoid() = default; + static constexpr char kType[] = "Sigmoid"; + 
Sigmoid() : CloneableModule(kType) {} std::vector> Forward(const std::vector> &input_tensors) override; }; } // namespace infini_train::nn diff --git a/infini_train/include/nn/modules/container.h b/infini_train/include/nn/modules/container.h index 4da8b3e6..28bceccf 100644 --- a/infini_train/include/nn/modules/container.h +++ b/infini_train/include/nn/modules/container.h @@ -13,6 +13,7 @@ class Tensor; namespace infini_train::nn { class Sequential : public CloneableModule { public: + static constexpr char kType[] = "Sequential"; // TODO(dcj): Use better ctor signature later. explicit Sequential(std::vector> &&layers); @@ -21,6 +22,7 @@ class Sequential : public CloneableModule { class ModuleDict : public CloneableModule { public: + static constexpr char kType[] = "ModuleDict"; // TODO(dcj): in torch, there is a dict with the order of insertion explicit ModuleDict(std::unordered_map> modules); diff --git a/infini_train/include/nn/modules/loss.h b/infini_train/include/nn/modules/loss.h index 5b3ddf25..f0543f53 100644 --- a/infini_train/include/nn/modules/loss.h +++ b/infini_train/include/nn/modules/loss.h @@ -8,7 +8,8 @@ namespace infini_train::nn { class CrossEntropyLoss : public CloneableModule { public: - CrossEntropyLoss() = default; + static constexpr char kType[] = "CrossEntropyLoss"; + CrossEntropyLoss() : CloneableModule(kType) {} std::vector> Forward(const std::vector> &input_tensors) override; }; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 9bc78bcc..266684c5 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -7,6 +7,7 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/nn/module_hook.h" namespace infini_train { class Tensor; @@ -50,6 +51,10 @@ class Module : public std::enable_shared_from_this { std::unordered_map> StateDict() const; + // operator() calls hooks and Forward + std::vector> operator()(const std::vector> &input_tensors); + + // Forward to be overridden by subclasses virtual std::vector> Forward(const std::vector> &input_tensors); virtual float TrainStep(const std::vector> &input_tensors, @@ -66,6 +71,12 @@ class Module : public std::enable_shared_from_this { virtual std::shared_ptr ReplicateForDataParallel(int device_idx) const; + // Hook registration methods + std::shared_ptr RegisterForwardPreHook(ForwardPreHook hook); + std::shared_ptr RegisterForwardPostHook(ForwardPostHook hook); + std::shared_ptr RegisterBackwardPreHook(BackwardPreHook hook); + std::shared_ptr RegisterBackwardPostHook(BackwardPostHook hook); + protected: const Device *device_ = nullptr; const std::string type_ = kUndefinedType; @@ -73,6 +84,12 @@ class Module : public std::enable_shared_from_this { std::unordered_map> parameters_; std::unordered_map> buffers_; + std::vector forward_pre_hooks_; + std::vector forward_post_hooks_; + std::vector backward_pre_hooks_; + std::vector backward_post_hooks_; + bool precision_check_registered_ = false; + private: std::unordered_map> NamedModules(const std::string &prefix = "", bool remove_duplicate = true, diff --git a/infini_train/include/nn/modules/normalization.h b/infini_train/include/nn/modules/normalization.h index 4dcdf807..111e96b7 100644 --- a/infini_train/include/nn/modules/normalization.h +++ b/infini_train/include/nn/modules/normalization.h @@ -13,6 +13,7 @@ class Device; namespace infini_train::nn { class LayerNorm : public CloneableModule { public: + static constexpr char kType[] = "LayerNorm"; 
static constexpr char kParamWeightName[] = "weight"; static constexpr char kParamBiasName[] = "bias"; diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 480c1286..bc178c19 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -27,7 +27,13 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size, int virtual_pipeline_parallel_size); + int pipeline_parallel_size, int virtual_pipeline_parallel_size, int precision_check_level = 0, + bool precision_check_all_ranks = false); + + enum class PrecisionCheckLevel { NONE, FUNCTION, MODULE }; + void SetPrecisionCheckLevel(PrecisionCheckLevel level); + PrecisionCheckLevel GetPrecisionCheckLevel() const; + bool GetPrecisionCheckAllRanks() const; int nnodes() const; @@ -83,14 +89,17 @@ class GlobalEnv { bool initialized_ = false; Layout layout_; + PrecisionCheckLevel precision_check_level_ = PrecisionCheckLevel::NONE; + bool precision_check_all_ranks_ = false; }; inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size, int virtual_pipeline_parallel) { + int pipeline_parallel_size, int virtual_pipeline_parallel, int precision_check_level = 0, + bool precision_check_all_ranks = false) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, - pipeline_parallel_size, virtual_pipeline_parallel); + pipeline_parallel_size, virtual_pipeline_parallel, precision_check_level, + precision_check_all_ranks); } - inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } diff --git a/infini_train/include/nn/parallel/tensor_parallel.h b/infini_train/include/nn/parallel/tensor_parallel.h index a2aa61ea..3a6f6498 100644 --- a/infini_train/include/nn/parallel/tensor_parallel.h +++ b/infini_train/include/nn/parallel/tensor_parallel.h @@ -103,8 +103,9 @@ class VocabParallelCrossEntropy : public autograd::Function { class VocabParallelCrossEntropyLoss : public nn::CloneableModule { public: + static constexpr char kType[] = "VocabParallelCrossEntropyLoss"; VocabParallelCrossEntropyLoss(int64_t vocab_size_original = 0, float label_smoothing = 0.f) - : vocab_size_original_(vocab_size_original), label_smoothing_(label_smoothing){}; + : CloneableModule(kType), vocab_size_original_(vocab_size_original), label_smoothing_(label_smoothing){}; std::vector> Forward(const std::vector> &input_tensors) override; diff --git a/infini_train/include/utils/precision_checker.h b/infini_train/include/utils/precision_checker.h new file mode 100644 index 00000000..6c09202b --- /dev/null +++ b/infini_train/include/utils/precision_checker.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +namespace infini_train { +class Tensor; + +namespace autograd { +class Function; +class HookHandle; +} // namespace autograd + +namespace nn { +class Module; +} // namespace nn + +namespace utils { + +class PrecisionChecker { +public: + struct Config { + bool check_nan = true; + bool check_inf = true; + bool print_stats = true; + bool abort_on_error = false; + }; + + static const Config& DefaultConfig() { + static Config default_config; + return default_config; + } + + static void 
RegisterForFunction(autograd::Function* func, const std::string& name = "", + const Config& config = DefaultConfig()); + + // Register hooks for a Module (checks forward inputs/outputs) + static void RegisterForModule(nn::Module* module, const std::string& name = "", + const Config& config = DefaultConfig()); + +private: + static void CheckTensors(const std::string& stage, const std::string& name, + const std::vector>& tensors, + const Config& config); +}; + +} // namespace utils +} // namespace infini_train diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 48ad02a9..de39302a 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -3,10 +3,13 @@ #include "glog/logging.h" #include "infini_train/include/autograd/accumulate.h" +#include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" +#include "infini_train/include/utils/precision_checker.h" namespace infini_train::autograd { @@ -16,6 +19,22 @@ std::vector> Function::Apply(const std::vectorSetDevice(); + // Register precision check hooks if enabled (before forward) + if (!precision_check_registered_) { + auto precision_level = nn::parallel::global::GlobalEnv::Instance().GetPrecisionCheckLevel(); + if (precision_level == nn::parallel::global::GlobalEnv::PrecisionCheckLevel::FUNCTION) { + utils::PrecisionChecker::RegisterForFunction(this, type_); + precision_check_registered_ = true; + } + } + + // Call forward pre-hooks + for (const auto& hook : forward_pre_hooks_) { + if (hook) { + hook(this, input_tensors); + } + } + std::vector> output_tensors; { autograd::NoGradGuard no_grad; @@ -24,6 +43,13 @@ std::vector> Function::Apply(const std::vector &grad_output, int g ++dependencies_reached_; if (grad_outputs_reached_ == grad_outputs_.size() && (dependencies_reached_ == dependencies_number_ || dependencies_number_ == 0)) { + + // Call backward pre-hooks + for (const auto& hook : backward_pre_hooks_) { + if (hook) { + hook(this, grad_outputs_); + } + } + std::vector> grad_inputs; { autograd::NoGradGuard no_grad; // no_grad in autograd.Function.Backward() grad_inputs = Backward(grad_outputs_); } + + // Call backward post-hooks + for (const auto& hook : backward_post_hooks_) { + if (hook) { + hook(this, grad_inputs, grad_outputs_); + } + } + saved_tensors_.clear(); grad_outputs_.clear(); grad_outputs_reached_ = 0; @@ -94,6 +136,23 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g auto &grad_input = grad_inputs[idx]; auto &[next_function, output_idx] = next_functions_[idx]; if (grad_input && next_function) { + // // Apply tensor backward hooks only for leaf tensors + // // Only AccumulateGrad corresponds to a leaf tensor that user can register hooks on + // auto accumulate_grad = std::dynamic_pointer_cast(next_function); + // if (accumulate_grad) { + // auto tensor = accumulate_grad->tensor(); + // if (tensor) { + // const auto& hooks = tensor->backward_post_hooks_(); + // for (const auto& hook : hooks) { + // if (hook) { + // auto modified_grad = hook(grad_input); + // if (modified_grad) { + // grad_input = modified_grad; + // } + // } + // } + // } + // } next_function->BackwardPartial(grad_input, output_idx); } } @@ -101,4 +160,24 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g 
} void Function::IncreaseDependenciesNumber() { ++dependencies_number_; } + +std::shared_ptr Function::RegisterForwardPreHook(FunctionForwardPreHook hook) { + forward_pre_hooks_.push_back(std::move(hook)); + return std::make_shared>(&forward_pre_hooks_, forward_pre_hooks_.size() - 1); +} + +std::shared_ptr Function::RegisterForwardPostHook(FunctionForwardPostHook hook) { + forward_post_hooks_.push_back(std::move(hook)); + return std::make_shared>(&forward_post_hooks_, forward_post_hooks_.size() - 1); +} + +std::shared_ptr Function::RegisterBackwardPreHook(FunctionBackwardPreHook hook) { + backward_pre_hooks_.push_back(std::move(hook)); + return std::make_shared>(&backward_pre_hooks_, backward_pre_hooks_.size() - 1); +} + +std::shared_ptr Function::RegisterBackwardPostHook(FunctionBackwardPostHook hook) { + backward_post_hooks_.push_back(std::move(hook)); + return std::make_shared>(&backward_post_hooks_, backward_post_hooks_.size() - 1); +} } // namespace infini_train::autograd diff --git a/infini_train/src/nn/modules/container.cc b/infini_train/src/nn/modules/container.cc index 6df46663..33707b63 100644 --- a/infini_train/src/nn/modules/container.cc +++ b/infini_train/src/nn/modules/container.cc @@ -7,7 +7,7 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -Sequential::Sequential(std::vector> &&layers) { +Sequential::Sequential(std::vector> &&layers) : CloneableModule(kType) { int idx = 0; for (auto &layer : layers) { modules_[std::to_string(idx)] = std::move(layer); @@ -17,11 +17,11 @@ Sequential::Sequential(std::vector> &&layers) { std::vector> Sequential::Forward(const std::vector> &input_tensors) { auto &x = const_cast> &>(input_tensors); - for (int idx = 0; idx < modules_.size(); ++idx) { x = modules_[std::to_string(idx)]->Forward(x); } + for (int idx = 0; idx < modules_.size(); ++idx) { x = (*modules_[std::to_string(idx)])(x); } return x; } -ModuleDict::ModuleDict(std::unordered_map> modules) { +ModuleDict::ModuleDict(std::unordered_map> modules) : CloneableModule(kType) { for (auto &[name, layer] : modules) { modules_[name] = std::move(layer); } } diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 4e0c6a28..b80a37ab 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -7,8 +7,11 @@ #include "glog/logging.h" +#include "infini_train/include/autograd/function.h" #include "infini_train/include/device.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" +#include "infini_train/include/utils/precision_checker.h" namespace infini_train::nn { @@ -125,6 +128,61 @@ std::vector> Module::Forward(const std::vector> Module::operator()(const std::vector> &input_tensors) { + // Register precision check hooks if enabled and not already registered + if (!precision_check_registered_) { + auto precision_level = parallel::global::GlobalEnv::Instance().GetPrecisionCheckLevel(); + if (precision_level == parallel::global::GlobalEnv::PrecisionCheckLevel::MODULE) { + utils::PrecisionChecker::RegisterForModule(this); + precision_check_registered_ = true; + } + } + + // Call forward pre-hooks + for (const auto& hook : forward_pre_hooks_) { + if (hook) { + hook(this, input_tensors); + } + } + + // Call actual Forward implementation + auto output_tensors = Forward(input_tensors); + + // Call forward post-hooks + for (const auto& hook : forward_post_hooks_) { + if (hook) { + hook(this, input_tensors, output_tensors); + } + } + + // Register backward hooks 
on output tensors' grad_fn + if (!backward_pre_hooks_.empty() || !backward_post_hooks_.empty()) { + for (const auto& output : output_tensors) { + if (output && output->grad_fn()) { + if (!backward_pre_hooks_.empty()) { + output->grad_fn()->RegisterBackwardPreHook( + [this](autograd::Function*, const std::vector>& grad_outputs) { + for (const auto& hook : backward_pre_hooks_) { + if (hook) hook(this, grad_outputs); + } + }); + } + if (!backward_post_hooks_.empty()) { + output->grad_fn()->RegisterBackwardPostHook( + [this](autograd::Function*, const std::vector>& grad_inputs, + const std::vector>& grad_outputs) { + for (const auto& hook : backward_post_hooks_) { + if (hook) hook(this, grad_inputs, grad_outputs); + } + }); + } + } + } + } + + return output_tensors; +} + void Module::To(const Device *device) { CHECK_NOTNULL(device); if (device == device_) { @@ -166,4 +224,24 @@ std::shared_ptr Module::ReplicateForDataParallel(int device_idx) const { // TODO(dcj): use device_idx later return std::make_shared(*this); } + +std::shared_ptr Module::RegisterForwardPreHook(ForwardPreHook hook) { + forward_pre_hooks_.push_back(std::move(hook)); + return std::make_shared>(&forward_pre_hooks_, forward_pre_hooks_.size() - 1); +} + +std::shared_ptr Module::RegisterForwardPostHook(ForwardPostHook hook) { + forward_post_hooks_.push_back(std::move(hook)); + return std::make_shared>(&forward_post_hooks_, forward_post_hooks_.size() - 1); +} + +std::shared_ptr Module::RegisterBackwardPreHook(BackwardPreHook hook) { + backward_pre_hooks_.push_back(std::move(hook)); + return std::make_shared>(&backward_pre_hooks_, backward_pre_hooks_.size() - 1); +} + +std::shared_ptr Module::RegisterBackwardPostHook(BackwardPostHook hook) { + backward_post_hooks_.push_back(std::move(hook)); + return std::make_shared>(&backward_post_hooks_, backward_post_hooks_.size() - 1); +} } // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/normalization.cc b/infini_train/src/nn/modules/normalization.cc index 5479fc72..4c7c68f1 100644 --- a/infini_train/src/nn/modules/normalization.cc +++ b/infini_train/src/nn/modules/normalization.cc @@ -9,7 +9,8 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, const Device *device) : eps_(eps) { +LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, const Device *device) + : CloneableModule(kType), eps_(eps) { device_ = device ? 
device : DeviceManager::Instance()->GetDefaultDevice(); parameters_[kParamWeightName] diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index 0dec0c8f..1a64ab8a 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -31,7 +31,7 @@ ParallelApply(const std::vector> &modules, auto worker = [&](const std::shared_ptr &module, const std::vector> &inputs, const Device *device, int idx) { device->SetDevice(); - auto output = module->Forward(inputs); + auto output = (*module)(inputs); results[idx] = output; }; @@ -86,7 +86,7 @@ std::vector> DataParallel::Forward(const std::vectorForward(scattered_inputs[0]); + return (*module)(scattered_inputs[0]); } auto replicas = function::Replicate(module, devices_); diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index a25a7d16..aacbfa38 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -50,7 +50,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod std::vector> DistributedDataParallel::Forward(const std::vector> &input_tensors) { - auto outputs = modules_[kModuleName]->Forward(input_tensors); + auto outputs = (*modules_[kModuleName])(input_tensors); if (reducer_) { reducer_->PrepareForBackward(); } diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 39cd95dd..80a01f57 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/global.h" +#include #include #include #include @@ -90,7 +91,8 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size, int virtual_pipeline_parallel_size) { + int pipeline_parallel_size, int virtual_pipeline_parallel_size, int precision_check_level, + bool precision_check_all_ranks) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -114,6 +116,16 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq layout_.sizes[PP] = pipeline_parallel_size_; layout_.InitStrides(); + // Initialize precision check level from parameter + if (precision_check_level == 1) { + precision_check_level_ = PrecisionCheckLevel::MODULE; + } else if (precision_check_level == 2) { + precision_check_level_ = PrecisionCheckLevel::FUNCTION; + } else { + precision_check_level_ = PrecisionCheckLevel::NONE; + } + precision_check_all_ranks_ = precision_check_all_ranks; + initialized_ = true; } @@ -182,6 +194,19 @@ Layout GlobalEnv::layout() const { return layout_; } +void GlobalEnv::SetPrecisionCheckLevel(PrecisionCheckLevel level) { + precision_check_level_ = level; +} + +GlobalEnv::PrecisionCheckLevel GlobalEnv::GetPrecisionCheckLevel() const { + return precision_check_level_; +} + +bool GlobalEnv::GetPrecisionCheckAllRanks() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return precision_check_all_ranks_; +} + namespace { inline const char *AxisName(Axis a) { return a == DP ? "DP" : (a == TP ? 
"TP" : "PP"); } diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 95dd3bbc..5eb32488 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -244,7 +244,7 @@ float PipelineSchedule::StepMicroBatches(const std::vectordevice()->Type(), dtype); auto target_on_device = target->To(activations[task.local_chunk_idx][mb][0]->GetDevice()); - loss = loss_fn->Forward( + loss = (*loss_fn)( {activations[task.local_chunk_idx][mb][0], std::make_shared(target_on_device)})[0]; loss = loss / n; } diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index 582b9bd2..968bb0a6 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -25,7 +25,7 @@ std::vector> PipelineStage::ForwardOneChunk(const std::v LOG(FATAL) << "PipelineStage::ForwardOneChunk: local_chunk_idx=" << local_chunk_idx << " out of range [0, " << chunks_.size() << ")"; } - return chunks_[local_chunk_idx]->Forward(inputs); + return (*chunks_[local_chunk_idx])(inputs); } bool PipelineStage::IsFirstStage() const { return stage_index_ == 0; } diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index b91028df..129e2f4b 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -269,8 +269,8 @@ std::vector> GatherFromSPRegionFunc(const std::shared_pt ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, int64_t out_features, bool bias, bool gather_output, bool input_is_parallel, bool skip_bias_add, bool sequence_parallel) - : bias_(bias), gather_output_(gather_output), input_is_parallel_(input_is_parallel), skip_bias_add_(skip_bias_add), - sequence_parallel_(sequence_parallel) { + : CloneableModule(kType), bias_(bias), gather_output_(gather_output), input_is_parallel_(input_is_parallel), + skip_bias_add_(skip_bias_add), sequence_parallel_(sequence_parallel) { auto tp_size = global::GetTensorParallelSize(); CHECK_GT(tp_size, 0) << "No available devices found"; CHECK_EQ(out_features % tp_size, 0) << "out_features must be divisible by TP world size for ColumnParallel"; @@ -315,8 +315,8 @@ ColumnParallelLinear::Forward(const std::vector> &input_ RowParallelLinear::RowParallelLinear(int64_t in_features, int64_t out_features, bool bias, bool reduce_output, bool input_is_parallel, bool skip_bias_add, bool sequence_parallel) - : bias_(bias), reduce_output_(reduce_output), input_is_parallel_(input_is_parallel), skip_bias_add_(skip_bias_add), - sequence_parallel_(sequence_parallel) { + : CloneableModule(kType), bias_(bias), reduce_output_(reduce_output), input_is_parallel_(input_is_parallel), + skip_bias_add_(skip_bias_add), sequence_parallel_(sequence_parallel) { auto tp_size = global::GetTensorParallelSize(); CHECK_GT(tp_size, 0) << "No available devices found"; CHECK_EQ(in_features % tp_size, 0) << "in_features must be divisible by TP world size for RowParallel"; @@ -362,7 +362,7 @@ RowParallelLinear::Forward(const std::vector> &input_ten VocabParallelEmbedding::VocabParallelEmbedding(int64_t num_embeddings, int64_t embedding_dim, bool reduce_scatter_embeddings) - : vocab_size_global_(num_embeddings), embedding_dim_(embedding_dim), + : CloneableModule(kType), vocab_size_global_(num_embeddings), embedding_dim_(embedding_dim), 
reduce_scatter_embeddings_(reduce_scatter_embeddings) { auto tp_size = global::GetTensorParallelSize(); CHECK_GT(tp_size, 0) << "No available devices found for VocabParallelEmbedding"; diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc new file mode 100644 index 00000000..0f040df9 --- /dev/null +++ b/infini_train/src/utils/precision_checker.cc @@ -0,0 +1,184 @@ +#include "infini_train/include/utils/precision_checker.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::utils { + +namespace { +std::ofstream& GetLogStream() { + static std::ofstream log_file; + static std::mutex init_mutex; + static bool initialized = false; + + if (!initialized) { + std::lock_guard lock(init_mutex); + if (!initialized) { + int rank = nn::parallel::global::GlobalEnv::Instance().global_proc_rank(); + std::string filename = "precision_check_rank_" + std::to_string(rank) + ".log"; + log_file.open(filename, std::ios::out | std::ios::trunc); + initialized = true; + } + } + return log_file; +} + +bool ShouldPrint() { + if (nn::parallel::global::GlobalEnv::Instance().GetPrecisionCheckAllRanks()) { + return true; + } + return nn::parallel::global::GlobalEnv::Instance().global_proc_rank() == 0; +} + +std::string GetTimestamp() { + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; + + std::tm tm; + localtime_r(&time_t, &tm); + + std::ostringstream oss; + oss << std::setfill('0') + << std::setw(2) << (tm.tm_mon + 1) + << std::setw(2) << tm.tm_mday << ' ' + << std::setw(2) << tm.tm_hour << ':' + << std::setw(2) << tm.tm_min << ':' + << std::setw(2) << tm.tm_sec << '.' + << std::setw(3) << ms.count(); + return oss.str(); +} +} // namespace + +void PrecisionChecker::CheckTensors(const std::string& stage, const std::string& name, + const std::vector>& tensors, + const Config& config) { + if (!ShouldPrint()) { + return; + } + + int rank = nn::parallel::global::GlobalEnv::Instance().global_proc_rank(); + + for (size_t i = 0; i < tensors.size(); ++i) { + if (!tensors[i]) continue; + + auto& tensor = tensors[i]; + + // Copy tensor to CPU if it's on GPU + std::shared_ptr cpu_tensor; + if (tensor->GetDevice()->Type() == DeviceType::kCUDA) { + auto cpu_device = DeviceManager::Instance()->GetDevice(DeviceType::kCPU); + cpu_tensor = std::make_shared(tensor->To(cpu_device)); + } else { + cpu_tensor = tensor; + } + + const float* data = static_cast(cpu_tensor->DataPtr()); + size_t size = cpu_tensor->NumElements(); + + bool has_nan = false; + bool has_inf = false; + + for (size_t j = 0; j < size; ++j) { + float val = data[j]; + if (std::isnan(val)) has_nan = true; + if (std::isinf(val)) has_inf = true; + } + + bool has_error = (config.check_nan && has_nan) || (config.check_inf && has_inf); + + if (has_error || config.print_stats) { + auto& log_stream = GetLogStream(); + std::string level = has_error ? 
"E" : "I"; + + log_stream << level << GetTimestamp() << " [Rank " << rank << "][PrecisionCheck] " + << stage << " " << name << " tensor[" << i << "]: ["; + + if (has_nan) log_stream << " NaN detected!"; + if (has_inf) log_stream << " Inf detected!"; + + if (config.print_stats) { + constexpr size_t max_print = 10; + for (size_t j = 0; j < std::min(size, max_print); ++j) { + if (j > 0) log_stream << ", "; + log_stream << data[j]; + } + if (size > max_print) log_stream << ", ..."; + } + log_stream << "]" << std::endl; + } + + if (has_error && config.abort_on_error) { + std::cerr << "Precision check failed, aborting!" << std::endl; + std::abort(); + } + } +} + +void PrecisionChecker::RegisterForFunction(autograd::Function* func, const std::string& name, + const Config& config) { + std::string func_name = name.empty() ? "Function" : name; + + func->RegisterForwardPreHook([func_name, config](autograd::Function*, + const std::vector>& inputs) { + CheckTensors("Forward Input", func_name, inputs, config); + }); + + func->RegisterForwardPostHook([func_name, config](autograd::Function*, + const std::vector>&, + const std::vector>& outputs) { + CheckTensors("Forward Output", func_name, outputs, config); + }); + + func->RegisterBackwardPreHook([func_name, config](autograd::Function*, + const std::vector>& grad_outputs) { + CheckTensors("Backward Input", func_name, grad_outputs, config); + }); + + func->RegisterBackwardPostHook([func_name, config](autograd::Function*, + const std::vector>& grad_inputs, + const std::vector>&) { + CheckTensors("Backward Output", func_name, grad_inputs, config); + }); +} + +void PrecisionChecker::RegisterForModule(nn::Module* module, const std::string& name, + const Config& config) { + std::string module_name = name.empty() ? module->type() : name; + + // module->RegisterForwardPreHook([module_name, config](nn::Module*, + // const std::vector>& inputs) { + // CheckTensors("Module Forward Input", module_name, inputs, config); + // }); + + module->RegisterForwardPostHook([module_name, config](nn::Module*, + const std::vector>&, + const std::vector>& outputs) { + CheckTensors("Module Forward Output", module_name, outputs, config); + }); + + // module->RegisterBackwardPreHook([module_name, config](nn::Module*, + // const std::vector>& grad_outputs) { + // CheckTensors("Module Backward Input", module_name, grad_outputs, config); + // }); + + module->RegisterBackwardPostHook([module_name, config](nn::Module*, + const std::vector>& grad_inputs, + const std::vector>&) { + CheckTensors("Module Backward Output", module_name, grad_inputs, config); + }); +} + +} // namespace infini_train::utils