example/gpt2/main.cc (7 additions, 3 deletions)
@@ -68,6 +68,9 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
 
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
+// precision check
+DEFINE_int32(precision_check, 0, "precision check level: 0=off, 1=module, 2=function");
+DEFINE_bool(precision_check_all_ranks, false, "enable precision check for all ranks (default: rank 0 only)");
 
 using namespace infini_train;
 
@@ -297,9 +300,9 @@ void Train(const nn::parallel::Rank &rank) {
 
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
         // (bs, seq_len, vocab_size)
-        auto logits = model->Forward({x, y})[0];
+        auto logits = (*model)({x, y})[0];
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
-        auto loss = loss_fn->Forward({logits, y})[0];
+        auto loss = (*loss_fn)({logits, y})[0];
         // FIXME(jym): verify gradient accumulation precision
         loss = loss / grad_accum_steps;
 
@@ -364,7 +367,8 @@ int main(int argc, char *argv[]) {
     google::InitGoogleLogging(argv[0]);
 
     nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
-                                     FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);
+                                     FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel, FLAGS_precision_check,
+                                     FLAGS_precision_check_all_ranks);
 
     LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
 
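The call-site change above (model->Forward({x, y}) becoming (*model)({x, y})) invokes the module through its call operator instead of calling the virtual Forward() directly, which gives the framework a single place to attach cross-cutting hooks such as the new precision check. The sketch below is illustrative only: it assumes the nn::Module base class exposes such an operator(), and MaybeRunPrecisionCheck is a hypothetical name, not InfiniTrain's actual API.

// Illustrative sketch only, not InfiniTrain's base class.
#include <memory>
#include <vector>

namespace sketch {

class Tensor; // stand-in for infini_train::Tensor

class Module {
public:
    virtual ~Module() = default;

    virtual std::vector<std::shared_ptr<Tensor>>
    Forward(const std::vector<std::shared_ptr<Tensor>> &inputs) = 0;

    // Single interception point: run the virtual Forward(), then any
    // framework-level hooks (precision check, tracing, ...).
    std::vector<std::shared_ptr<Tensor>> operator()(const std::vector<std::shared_ptr<Tensor>> &inputs) {
        auto outputs = Forward(inputs);
        MaybeRunPrecisionCheck(outputs); // hypothetical hook, driven by --precision_check
        return outputs;
    }

private:
    void MaybeRunPrecisionCheck(const std::vector<std::shared_ptr<Tensor>> & /*outputs*/) {
        // No-op in this sketch; a real hook would compare activations against a
        // reference when the configured check level is greater than 0.
    }
};

} // namespace sketch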
example/gpt2/net.cc (22 additions, 22 deletions)
@@ -47,7 +47,7 @@ NewGELU::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
 }
 
 CausalSelfAttention::CausalSelfAttention(const GPT2Config &config)
-    : config_(config), n_head_(config.n_head), n_embd_(config.n_embd) {
+    : CloneableModule(kType), config_(config), n_head_(config.n_head), n_embd_(config.n_embd) {
     auto tp_world_size = nn::parallel::global::GetTensorParallelSize();
     CHECK_EQ(config.n_embd % config.n_head, 0);
     CHECK_EQ(n_head_ % tp_world_size, 0) << "n_head must be divisible by TP world size";
@@ -89,7 +89,7 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
 
     // (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C)
     // -> Split -> (3, B, T, local_C)
-    auto qkv = modules_[kCAttnLayerName]->Forward(x)[0]->Split(local_C, 2);
+    auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2);
 
     // (B, T, local_C)
     auto q = qkv[0];
@@ -120,12 +120,12 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
 
     // Get full tensor
     // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
-    y = modules_[kCProjLayerName]->Forward({y})[0];
+    y = (*modules_[kCProjLayerName])({y})[0];
     // (B, T, C) == (bs, seq_len, n_embd)
     return {y};
 }
 
-MLP::MLP(const GPT2Config &config) {
+MLP::MLP(const GPT2Config &config) : CloneableModule(kType) {
     // c_fc: ColumnParallel (input full, output parallel)
     modules_[kCFcLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
         /*in_features=*/config.n_embd, /*out_features=*/4 * config.n_embd,
@@ -150,16 +150,16 @@ MLP::MLP(const GPT2Config &config) {
 std::vector<std::shared_ptr<infini_train::Tensor>>
 MLP::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     // (B, T, C) -> ColumnParallelLinear(C, 4 * C) -> (B, T, 4 * C_local)
-    auto x1 = modules_[kCFcLayerName]->Forward(x);
+    auto x1 = (*modules_[kCFcLayerName])(x);
     // (B, T, 4 * C_local) -> GELU -> (B, T, 4 * C_local)
-    auto x2 = modules_[kGeluLayerName]->Forward(x1);
+    auto x2 = (*modules_[kGeluLayerName])(x1);
     // (B, T, 4 * C_local) -> RowParallelLinear(4 * C, C) -> (B, T, C)
-    auto x3 = modules_[kCProjLayerName]->Forward(x2);
+    auto x3 = (*modules_[kCProjLayerName])(x2);
     // (B, T, C)
     return x3;
 }
 
-Block::Block(const GPT2Config &config) {
+Block::Block(const GPT2Config &config) : CloneableModule(kType) {
     modules_[kLn1LayerName] = std::make_shared<nn::LayerNorm>(std::vector<int64_t>{config.n_embd});
     modules_[kAttnLayerName] = std::make_shared<CausalSelfAttention>(config);
     modules_[kLn2LayerName] = std::make_shared<nn::LayerNorm>(std::vector<int64_t>{config.n_embd});
@@ -170,15 +170,15 @@ std::vector<std::shared_ptr<infini_train::Tensor>>
 Block::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd)
     // -> Add -> (bs, seq_len, n_embd)
-    auto x1 = x[0] + modules_[kAttnLayerName]->Forward(modules_[kLn1LayerName]->Forward(x))[0];
+    auto x1 = x[0] + (*modules_[kAttnLayerName])((*modules_[kLn1LayerName])(x))[0];
     // (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd)
     // -> Add -> (bs, seq_len, n_embd)
-    auto x2 = x1 + modules_[kMlpLayerName]->Forward(modules_[kLn2LayerName]->Forward({x1}))[0];
+    auto x2 = x1 + (*modules_[kMlpLayerName])((*modules_[kLn2LayerName])({x1}))[0];
     // (bs, seq_len, n_embd)
     return {x2};
 }
 
-GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : config_(config) {
+GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : CloneableModule(kType), config_(config) {
     modules_[kWTELayerName] = std::make_shared<nn::parallel::VocabParallelEmbedding>(
         config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled());
     modules_[kWPELayerName] = std::make_shared<nn::Embedding>(config_.block_size, config_.n_embd);
@@ -207,15 +207,15 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
     auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device);
 
     // (B, T) -> Embedding(V_local, C) -> (B, T, C)
-    auto tok_emb = modules_[kWTELayerName]->Forward({x1})[0];
+    auto tok_emb = (*modules_[kWTELayerName])({x1})[0];
 
     // (T) -> Embedding(T_max, C) -> (T, C)
-    auto pos_emb = modules_[kWPELayerName]->Forward({pos})[0];
+    auto pos_emb = (*modules_[kWPELayerName])({pos})[0];
     // (B, T, C)
     return {tok_emb + pos_emb};
 }
 
-GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : config_(config) {
+GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : CloneableModule(kType), config_(config) {
     std::vector<std::shared_ptr<nn::Module>> h;
     for (int64_t i = start_layer; i < end_layer; ++i) {
         auto layer = std::make_shared<Block>(config);
@@ -228,11 +228,11 @@ std::vector<std::shared_ptr<infini_train::Tensor>>
 GPT2Chunk::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     auto x1 = x[0];
     // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd)
-    for (auto &h : *std::dynamic_pointer_cast<nn::ModuleList>(modules_[kHLayerName])) { x1 = h->Forward({x1})[0]; }
+    for (auto &h : *std::dynamic_pointer_cast<nn::ModuleList>(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; }
     return {x1};
 }
 
-GPT2LastStage::GPT2LastStage(const GPT2Config &config) : config_(config) {
+GPT2LastStage::GPT2LastStage(const GPT2Config &config) : CloneableModule(kType), config_(config) {
     modules_[kLnFLayerName] = std::make_shared<nn::LayerNorm>(std::vector<int64_t>{config_.n_embd});
     // don't init this one, we will tie weights
     modules_[kLMHeadLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
@@ -248,15 +248,15 @@ GPT2LastStage::GPT2LastStage(const GPT2Config &config) : config_(config) {
 std::vector<std::shared_ptr<infini_train::Tensor>>
 GPT2LastStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     // (B, T, C) -> Layernorm -> (B, T, C)
-    auto x1 = modules_[kLnFLayerName]->Forward(x);
+    auto x1 = (*modules_[kLnFLayerName])(x);
 
     // TODO(dcj): add inference-time mini-optimization
     // (B, T, C) -> Linear(C, V) -> (B, T, V)
-    return modules_[kLMHeadLayerName]->Forward(x1);
+    return (*modules_[kLMHeadLayerName])(x1);
 }
 
 GPT2::GPT2(const GPT2Config &config)
-    : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo(
+    : CloneableModule(kType), config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo(
           config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
          nn::parallel::global::GetVirtualPipelineParallelSize())) {
     auto tp_world_size = nn::parallel::global::GetTensorParallelSize();
@@ -316,11 +316,11 @@ GPT2::GPT2(const GPT2Config &config)
 
 std::vector<std::shared_ptr<infini_train::Tensor>>
 GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
-    auto x1 = modules_[kPPFirstStageName]->Forward(x);
+    auto x1 = (*modules_[kPPFirstStageName])(x);
     for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) {
-        x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1);
+        x1 = (*modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)])(x1);
     }
-    return modules_[kPPLastStageName]->Forward(x1);
+    return (*modules_[kPPLastStageName])(x1);
 }
 
 std::shared_ptr<GPT2> GPT2::FromPretrained(ModelType model_type) {
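Every constructor in net.cc now forwards a kType tag to its CloneableModule base, and every submodule invocation goes through operator(). A hypothetical module written against the same convention could look like the following; Residual and its layer name are invented for illustration, and only the pattern mirrors this change.

// Hypothetical example, not part of the PR: a new module following the updated
// convention (kType passed to CloneableModule, submodules invoked as functors).
class Residual : public infini_train::nn::CloneableModule<Residual> {
public:
    static constexpr char kType[] = "Residual";
    static constexpr char kInnerLayerName[] = "inner"; // invented name

    explicit Residual(std::shared_ptr<infini_train::nn::Module> inner) : CloneableModule(kType) {
        modules_[kInnerLayerName] = std::move(inner);
    }

    std::vector<std::shared_ptr<infini_train::Tensor>>
    Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override {
        // y = x + inner(x); calling the submodule as a functor keeps it visible to
        // the same hooks that wrap the named modules above.
        auto y = x[0] + (*modules_[kInnerLayerName])({x[0]})[0];
        return {y};
    }
};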
example/gpt2/net.h (10 additions, 0 deletions)
@@ -23,12 +23,16 @@ struct GPT2Config {
 
 class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
 public:
+    static constexpr char kType[] = "NewGELU";
+    NewGELU() : CloneableModule(kType) {}
+
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
 };
 
 class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfAttention> {
 public:
+    static constexpr char kType[] = "CausalSelfAttention";
     static constexpr char kCAttnLayerName[] = "c_attn";
     static constexpr char kCProjLayerName[] = "c_proj";
 
@@ -49,6 +53,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
 
 class MLP : public infini_train::nn::CloneableModule<MLP> {
 public:
+    static constexpr char kType[] = "MLP";
     static constexpr char kCFcLayerName[] = "c_fc";
     static constexpr char kGeluLayerName[] = "gelu";
     static constexpr char kCProjLayerName[] = "c_proj";
@@ -61,6 +66,7 @@ class MLP : public infini_train::nn::CloneableModule<MLP> {
 
 class Block : public infini_train::nn::CloneableModule<Block> {
 public:
+    static constexpr char kType[] = "Block";
     static constexpr char kLn1LayerName[] = "ln_1";
     static constexpr char kAttnLayerName[] = "attn";
     static constexpr char kLn2LayerName[] = "ln_2";
@@ -74,6 +80,7 @@ class Block : public infini_train::nn::CloneableModule<Block> {
 
 class GPT2FirstStage : public infini_train::nn::CloneableModule<GPT2FirstStage> {
 public:
+    static constexpr char kType[] = "GPT2FirstStage";
     static constexpr char kWTELayerName[] = "wte";
     static constexpr char kWPELayerName[] = "wpe";
 
@@ -88,6 +95,7 @@ class GPT2FirstStage : public infini_train::nn::CloneableModule<GPT2FirstStage>
 
 class GPT2Chunk : public infini_train::nn::CloneableModule<GPT2Chunk> {
 public:
+    static constexpr char kType[] = "GPT2Chunk";
     static constexpr char kHLayerName[] = "h";
 
     GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer);
@@ -101,6 +109,7 @@ class GPT2Chunk : public infini_train::nn::CloneableModule<GPT2Chunk> {
 
 class GPT2LastStage : public infini_train::nn::CloneableModule<GPT2LastStage> {
 public:
+    static constexpr char kType[] = "GPT2LastStage";
     static constexpr char kLnFLayerName[] = "ln_f";
     static constexpr char kLMHeadLayerName[] = "lm_head";
 
@@ -115,6 +124,7 @@ class GPT2LastStage : public infini_train::nn::CloneableModule<GPT2LastStage> {
 
 class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
 public:
+    static constexpr char kType[] = "GPT2";
     static constexpr char kTransformerLayerName[] = "transformer";
 
     enum class ModelType : int8_t {
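Every class in net.h now declares a static constexpr kType and hands it to its CloneableModule base. The base class itself is not shown in this diff; the sketch below only illustrates the general shape such a CRTP base could take (a stored type name plus a Clone() derived from the template parameter). Everything here other than the kType convention is an assumption.

// Minimal sketch of a CRTP base taking a type tag; not InfiniTrain's actual
// CloneableModule, which is outside this diff.
#include <memory>
#include <string>

namespace sketch {

class ModuleBase {
public:
    explicit ModuleBase(std::string type) : type_(std::move(type)) {}
    virtual ~ModuleBase() = default;
    const std::string &Type() const { return type_; } // e.g. usable in precision-check reports
    virtual std::shared_ptr<ModuleBase> Clone() const = 0;

private:
    std::string type_;
};

template <typename Derived> class CloneableModule : public ModuleBase {
public:
    using ModuleBase::ModuleBase;

    // The CRTP parameter lets every subclass get a correctly typed Clone() for free.
    std::shared_ptr<ModuleBase> Clone() const override {
        return std::make_shared<Derived>(static_cast<const Derived &>(*this));
    }
};

// Each concrete module declares a kType constant and hands it to the base,
// mirroring what NewGELU, Block, GPT2, etc. now do in net.h.
class ExampleModule : public CloneableModule<ExampleModule> {
public:
    static constexpr char kType[] = "ExampleModule";
    ExampleModule() : CloneableModule(kType) {}
};

} // namespace sketch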
example/llama3/main.cc (7 additions, 3 deletions)
@@ -66,6 +66,9 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the
 DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
+// precision check
+DEFINE_int32(precision_check, 0, "precision check level: 0=off, 1=module, 2=function");
+DEFINE_bool(precision_check_all_ranks, false, "enable precision check for all ranks (default: rank 0 only)");
 
 using namespace infini_train;
 
@@ -273,9 +276,9 @@ void Train(const nn::parallel::Rank &rank) {
 
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
         // (bs, seq_len, vocab_size)
-        auto logits = model->Forward({x, y})[0];
+        auto logits = (*model)({x, y})[0];
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
-        auto loss = loss_fn->Forward({logits, y})[0];
+        auto loss = (*loss_fn)({logits, y})[0];
         // FIXME(jym): verify gradient accumulation precision
         loss = loss / grad_accum_steps;
 
@@ -340,7 +343,8 @@ int main(int argc, char *argv[]) {
     google::InitGoogleLogging(argv[0]);
 
     nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
-                                     FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);
+                                     FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel, FLAGS_precision_check,
+                                     FLAGS_precision_check_all_ranks);
 
     LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
 
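The llama3 driver gets the same treatment as the gpt2 one: two new gflags and two extra arguments to InitAllEnv. What the library does with those values is outside this diff; the sketch below is a purely hypothetical holder showing how a check level plus an all-ranks switch might be consulted. Every name here is an assumption; only the flag semantics (0=off, 1=module, 2=function; default rank 0 only) come from the flag descriptions above.

// Hypothetical sketch, not InfiniTrain code: one way the two flag values could be
// stored during environment setup and queried by a per-module check.
namespace sketch {

struct PrecisionCheckConfig {
    int level = 0;          // --precision_check: 0=off, 1=module, 2=function
    bool all_ranks = false; // --precision_check_all_ranks
};

inline PrecisionCheckConfig &GlobalPrecisionCheck() {
    static PrecisionCheckConfig config; // filled once during InitAllEnv-style setup
    return config;
}

// e.g. during setup: GlobalPrecisionCheck() = {FLAGS_precision_check, FLAGS_precision_check_all_ranks};
inline bool ShouldCheck(int global_rank) {
    const auto &cfg = GlobalPrecisionCheck();
    // Default: only rank 0 checks; the all-ranks switch widens it to every rank.
    return cfg.level > 0 && (cfg.all_ranks || global_rank == 0);
}

} // namespace sketch

Because the flags are declared with gflags DEFINE_int32/DEFINE_bool, they surface on the training binaries' command line as --precision_check and --precision_check_all_ranks.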