This repository was archived by the owner on Oct 14, 2025. It is now read-only.
apex/apex/transformer/tensor_parallel/layers.py (51 additions, 6 deletions)
@@ -197,6 +197,7 @@ def __init__(
         init_method=init.xavier_normal_,
         *,
         params_dtype: torch.dtype=torch.float32,
+        resume_from_checkpoint: bool = False,
         use_cpu_initialization: bool = False,
     ):
         super().__init__()
@@ -224,8 +225,22 @@ def __init__(
             self.vocab_end_index - self.vocab_start_index
         )

-        # Allocate weights and initialize.
-        if use_cpu_initialization:
+        # Allocate weights and initialize if necessary
+        if resume_from_checkpoint:
+            # Resuming from a checkpoint needs no initialization, so always allocate memory directly on the device
+            self.weight = Parameter(
+                torch.empty(
+                    self.num_embeddings_per_partition,
+                    self.embedding_dim,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                    requires_grad=True
+                )
+            )
+            set_tensor_model_parallel_attributes(
+                tensor=self.weight, is_parallel=True, dim=0, stride=1
+            )
+        elif use_cpu_initialization:
             self.weight = Parameter(
                 torch.empty(
                     self.num_embeddings_per_partition,
@@ -490,6 +505,7 @@ class ColumnParallelLinear(torch.nn.Module):
     Keyword Arguments:
         no_async_tensor_model_parallel_allreduce:
         params_dtype:
+        resume_from_checkpoint: if True, do not initialize (so use_cpu_initialization will have no effect)
         use_cpu_initialization:
         gradient_accumulation_fusion:
         accumulation_in_fp16:
@@ -509,6 +525,7 @@ def __init__(
         *,
         no_async_tensor_model_parallel_allreduce=False,
         params_dtype=torch.float32,
+        resume_from_checkpoint=False,
         use_cpu_initialization=False,
         gradient_accumulation_fusion=False,
         accumulation_in_fp16: bool = False,
@@ -530,8 +547,21 @@ def __init__(
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
+        # Initialize weight if necessary
+        if resume_from_checkpoint:
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size_per_partition,
+                    self.input_size,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                    requires_grad=True
+                )
+            )
+            set_tensor_model_parallel_attributes(
+                tensor=self.weight, is_parallel=True, dim=0, stride=stride
+            )
+        elif use_cpu_initialization:
             self.weight = Parameter(
                 torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype)
             )
@@ -682,6 +712,7 @@ class RowParallelLinear(torch.nn.Module):
                    adding bias but instead return it.
     Keyword Arguments:
         params_dtype:
+        resume_from_checkpoint: if True, do not initialize, only allocate memory on device (hence use_cpu_initialization will have no effect)
         use_cpu_initialization:
         gradient_accumulation_fusion:
         accumulation_in_fp16:
@@ -700,6 +731,7 @@ def __init__(
         skip_bias_add=False,
         *,
         params_dtype=torch.float32,
+        resume_from_checkpoint=False,
         use_cpu_initialization=False,
         gradient_accumulation_fusion=False,
         accumulation_in_fp16: bool = False,
@@ -726,8 +758,21 @@ def __init__(
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
+        # Initialize weight if necessary
+        if resume_from_checkpoint:
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size,
+                    self.input_size_per_partition,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                    requires_grad=True
+                )
+            )
+            set_tensor_model_parallel_attributes(
+                tensor=self.weight, is_parallel=True, dim=1, stride=stride
+            )
+        elif use_cpu_initialization:
             self.weight = Parameter(
                 torch.empty(
                     self.output_size, self.input_size_per_partition, dtype=params_dtype
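Taken together, the hunks above give VocabParallelEmbedding, ColumnParallelLinear, and RowParallelLinear the same fast path: with resume_from_checkpoint=True, the layer allocates an uninitialized weight buffer directly on the device and tags it with its tensor-parallel attributes, skipping the random initializer (and any CPU-side staging), since a subsequent checkpoint load overwrites the values anyway. A minimal stand-alone sketch of that pattern in plain PyTorch (illustrative names only, not apex's API):

import torch
from torch.nn import Parameter, init

def make_weight(out_features, in_features, resume_from_checkpoint=False,
                device="cuda" if torch.cuda.is_available() else "cpu",
                dtype=torch.float32):
    if resume_from_checkpoint:
        # Allocate only; values stay garbage until a checkpoint fills them.
        return Parameter(torch.empty(out_features, in_features,
                                     device=device, dtype=dtype))
    weight = Parameter(torch.empty(out_features, in_features,
                                   device=device, dtype=dtype))
    init.xavier_normal_(weight)  # the step the fast path skips
    return weight

# A later state-dict load overwrites the uninitialized buffer in place:
w = make_weight(1024, 4096, resume_from_checkpoint=True)
state = {"weight": torch.randn(1024, 4096)}  # stand-in for a real checkpoint
w.data.copy_(state["weight"].to(device=w.device, dtype=w.dtype))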
k8s/example_manifests/mpi_train_gpt23b.yaml (2 additions, 0 deletions)

@@ -6,6 +6,8 @@ spec:
   slotsPerWorker: 1
   runPolicy:
     cleanPodPolicy: Running
+    # Uncomment the line below to allow the launcher to wait longer for the workers to join
+    # backoffLimit: 20
   mpiReplicaSpecs:
     Launcher:
       replicas: 1
(file name not shown: Megatron checkpoint conversion script)
@@ -221,7 +221,11 @@ def convert_checkpoint(p, args):
     output_folder = output_folder + f"_pp_rank_{p:03d}"
     if not os.path.exists(output_folder):
         os.makedirs(output_folder)
-    torch.save(out_model, f"{output_folder}/model_optim_rng.ckpt") #, (not master_only), global_master=True)
+    if args.is_xser:
+        from nemo.collections.nlp.parts.serialization import save
+        save(out_model, f"{output_folder}/model_optim_rng.ckpt")
+    else:
+        torch.save(out_model, f"{output_folder}/model_optim_rng.ckpt") #, (not master_only), global_master=True)
     print("Done saving Megatron checkpoint")

@@ -284,11 +288,16 @@ def convert_checkpoint(p, args):
     type=bool,
     help="To use bias in the model layers",
 )
+parser.add_argument(
+    "--is_xser",
+    action="store_true",
+    help="Enable serialized saving",
+)

 args = parser.parse_args()
 pp_to_bin = get_pipeline_bin_division(args)
 PP = args.pp_degree
 args.pp_to_bin = pp_to_bin
 f = partial(convert_checkpoint, args=args)
 with Pool(PP) as p:
-    p.map(f, [i for i in range(PP)])
\ No newline at end of file
+    p.map(f, [i for i in range(PP)])
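The save-path switch above recurs verbatim in each converter below: when --is_xser is set, the model is written with NeMo's tensor-serialized format instead of plain torch.save. If the scripts were refactored, a single shared helper could collapse the repeated branch; the sketch below is hypothetical and not part of the diff, though the nemo import is exactly the one the hunks use:

import torch

def save_checkpoint(obj, path, use_xser=False):
    # Write with NeMo's tensor-serialized (xser) format when requested;
    # otherwise fall back to plain torch.save.
    if use_xser:
        from nemo.collections.nlp.parts.serialization import save
        save(obj, path)
    else:
        torch.save(obj, path)

# Hypothetical call site, mirroring the converters above:
# save_checkpoint(out_model, f"{output_folder}/model_optim_rng.ckpt", args.is_xser)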
(file name not shown: a second, near-identical Megatron checkpoint conversion script)
@@ -222,7 +222,11 @@ def convert_checkpoint(p, args):
     output_folder = output_folder + f"_pp_rank_{p:03d}"
     if not os.path.exists(output_folder):
         os.makedirs(output_folder)
-    torch.save(out_model, f"{output_folder}/model_optim_rng.ckpt") #, (not master_only), global_master=True)
+    if args.is_xser:
+        from nemo.collections.nlp.parts.serialization import save
+        save(out_model, f"{output_folder}/model_optim_rng.ckpt")
+    else:
+        torch.save(out_model, f"{output_folder}/model_optim_rng.ckpt") #, (not master_only), global_master=True)
     print("Done saving Megatron checkpoint")

@@ -285,11 +289,16 @@ def convert_checkpoint(p, args):
     type=bool,
     help="To use bias in the model layers",
 )
+parser.add_argument(
+    "--is_xser",
+    action="store_true",
+    help="Enable serialized saving",
+)

 args = parser.parse_args()
 pp_to_bin = get_pipeline_bin_division(args)
 PP = args.pp_degree
 args.pp_to_bin = pp_to_bin
 f = partial(convert_checkpoint, args=args)
 with Pool(PP) as p:
-    p.map(f, [i for i in range(PP)])
\ No newline at end of file
+    p.map(f, [i for i in range(PP)])
(file name not shown: NeMo checkpoint conversion script)
@@ -165,11 +165,11 @@ def convert_checkpoint(args):
     path = Path(path)
     path.mkdir(parents=True, exist_ok=True)
     print("saving nemo checkpoint")
-    torch.save(
-        checkpoint,
-        str(path)
-        + "/megatron_gpt.ckpt",
-    )
+    if args.is_xser:
+        from nemo.collections.nlp.parts.serialization import save
+        save(checkpoint, str(path) + "/megatron_gpt.ckpt")
+    else:
+        torch.save(checkpoint, str(path) + "/megatron_gpt.ckpt")
     print("Done saving nemo checkpoint")

@@ -211,5 +211,11 @@ def convert_checkpoint(args):
     type=bool,
     help="Share embedding and output layer weights.",
 )
+parser.add_argument(
+    "--is_xser",
+    action="store_true",
+    help="Enable serialized saving",
+)
+
 args = parser.parse_args()
 convert_checkpoint(args)
(file name not shown: another Megatron checkpoint conversion script)
@@ -172,8 +172,12 @@ def convert_checkpoint(p):
     output_folder = output_folder + f"_pp_rank_{p:03d}"
     if not os.path.exists(output_folder):
         os.makedirs(output_folder)
-    torch.save(out_model,
-               f"{output_folder}/model_optim_rng.ckpt")  # , (not master_only), global_master=True)
+    if args.is_xser:
+        from nemo.collections.nlp.parts.serialization import save
+        save(out_model, f"{output_folder}/model_optim_rng.ckpt")
+    else:
+        torch.save(out_model,
+                   f"{output_folder}/model_optim_rng.ckpt")  # , (not master_only), global_master=True)
     print("Done saving Megatron checkpoint")

@@ -213,6 +217,11 @@ def convert_checkpoint(p):
     type=bool,
     help="Save weights in bf16.",
 )
+parser.add_argument(
+    "--is_xser",
+    action="store_true",
+    help="Enable serialized saving",
+)

 args = parser.parse_args()
 convert_checkpoint(args)
(file name not shown: sharded Megatron checkpoint conversion script)
@@ -216,8 +216,12 @@ def convert_checkpoint(p, args, config):
     output_folder = output_folder + f"_pp_rank_{p:03d}"
     if not os.path.exists(output_folder):
         os.makedirs(output_folder)
-    torch.save(out_model,
-               f"{output_folder}/model_optim_rng.ckpt")  # , (not master_only), global_master=True)
+    if args.is_xser:
+        from nemo.collections.nlp.parts.serialization import save
+        save(out_model, f"{output_folder}/model_optim_rng.ckpt")
+    else:
+        torch.save(out_model,
+                   f"{output_folder}/model_optim_rng.ckpt")  # , (not master_only), global_master=True)
     print("Done saving Megatron checkpoint")

@@ -269,6 +273,11 @@ def convert_checkpoint(p, args, config):
     type=int,
     help="Number of shards in the save checkpoint",
 )
+parser.add_argument(
+    "--is_xser",
+    action="store_true",
+    help="Enable serialized saving",
+)

 args = parser.parse_args()

@@ -279,4 +288,4 @@ def convert_checkpoint(p, args, config):
 PP = args.pp_degree
 f = partial(convert_checkpoint, args=args, config=config)
 with Pool(PP) as p:
-    p.map(f, [i for i in range(PP)])
\ No newline at end of file
+    p.map(f, [i for i in range(PP)])
(file name not shown: checkpoint conversion utility)
@@ -72,7 +72,12 @@ def get_checkpoints_for_pp(pp: int, path_to_checkpoints: str, PP: int=1, TP: int
     template = join(path_to_checkpoints, pp_str, '*.ckpt')

     tp_paths = sorted(glob(template))
-    return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(tp_paths)}
+    if is_xser:
+        import nemo.collections.nlp.parts.serialization as nser
+        load_fn = lambda path: nser.load(path, cpu_only=True)
+    else:
+        load_fn = torch.load
+    return {i: load_fn(p)['state_dict'] for i, p in enumerate(tp_paths)}

@@ -83,7 +88,12 @@ def get_checkpoints_for_tp(tp: int, path_to_checkpoints: str, is_xser: bool=False):
     template = join(path_to_checkpoints, f'tp_rank_{tp_str}_pp_rank_*', '*.ckpt')

     pp_paths = sorted(glob(template))
-    return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(pp_paths)}
+    if is_xser:
+        import nemo.collections.nlp.parts.serialization as nser
+        load_fn = lambda path: nser.load(path, cpu_only=True)
+    else:
+        load_fn = torch.load
+    return {i: load_fn(p)['state_dict'] for i, p in enumerate(pp_paths)}

 def _get_nemo_key(k, nemo_key = 'model.language_model.'):
     if "final_layernorm" in k:
@@ -205,8 +215,8 @@ def convert_checkpoint(config_file,
 )
 parser.add_argument(
     "--is_xser",
-    default=False,
-    type=bool
+    action="store_true",
+    help="Enable serialized loading",
 )
 args = parser.parse_args()
 convert_checkpoint(args.config_file, args.path_to_checkpoints, args.output_path, args.checkpoint_version, args.is_xser)
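On the load side, the hunks above swap torch_xla's xser.load for NeMo's serialization module and pass cpu_only=True, so deserialized tensors stay on the host instead of being materialized on an XLA device during conversion. A hypothetical helper in the same spirit; the nser.load(path, cpu_only=True) call is taken from the hunks, while map_location="cpu" on the fallback path is an added assumption for parity:

import torch

def load_state_dict(path, use_xser=False):
    if use_xser:
        # xser checkpoints store tensor data out of line; cpu_only=True
        # keeps the load on the host during conversion.
        import nemo.collections.nlp.parts.serialization as nser
        return nser.load(path, cpu_only=True)['state_dict']
    return torch.load(path, map_location="cpu")['state_dict']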
(file name not shown: LLaMA checkpoint conversion utility)
@@ -7,7 +7,6 @@
 import re
 import numpy as np
 import torch
-import torch_xla.utils.serialization as xser

 def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
     # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
@@ -81,7 +80,12 @@ def get_checkpoints_for_pp(pp: int, path_to_checkpoints: str, PP: int=1, TP: int
     template = join(path_to_checkpoints, pp_str, f'*megatron_llama--step={max_step_recorded}*.ckpt')

     tp_paths = sorted(glob(template))
-    return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(tp_paths)}
+    if is_xser:
+        import nemo.collections.nlp.parts.serialization as nser
+        load_fn = lambda path: nser.load(path, cpu_only=True)
+    else:
+        load_fn = torch.load
+    return {i: load_fn(p)['state_dict'] for i, p in enumerate(tp_paths)}

@@ -92,7 +96,12 @@ def get_checkpoints_for_tp(tp: int, path_to_checkpoints: str, is_xser: bool=False):
     template = join(path_to_checkpoints, f'tp_rank_{tp_str}_pp_rank_*', '*.ckpt')

     pp_paths = sorted(glob(template))
-    return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(pp_paths)}
+    if is_xser:
+        import nemo.collections.nlp.parts.serialization as nser
+        load_fn = lambda path: nser.load(path, cpu_only=True)
+    else:
+        load_fn = torch.load
+    return {i: load_fn(p)['state_dict'] for i, p in enumerate(pp_paths)}

 def _get_nemo_key(k, nemo_key = 'model.language_model.'):
     if "final_layernorm" in k:
@@ -218,8 +227,8 @@ def convert_checkpoint(config_file,
 )
 parser.add_argument(
     "--is_xser",
-    default=False,
-    type=bool
+    action="store_true",
+    help="Enable serialized loading",
 )
 args = parser.parse_args()
-convert_checkpoint(args.config_file, args.path_to_checkpoints, args.output_path, args.checkpoint_version, args.is_xser)
\ No newline at end of file
+convert_checkpoint(args.config_file, args.path_to_checkpoints, args.output_path, args.checkpoint_version, args.is_xser)
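Besides adding --is_xser where it was missing, the two hunks above also fix how the existing flag was declared. argparse's type=bool does not parse booleans: it applies bool() to the raw argument string, and any non-empty string, including "False", is truthy, so --is_xser False would silently enable xser. action="store_true" gives the intended present/absent semantics. A minimal repro:

import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--is_xser", default=False, type=bool)
print(broken.parse_args(["--is_xser", "False"]).is_xser)  # True: bool("False") is truthy

fixed = argparse.ArgumentParser()
fixed.add_argument("--is_xser", action="store_true")
print(fixed.parse_args([]).is_xser)             # False: flag omitted
print(fixed.parse_args(["--is_xser"]).is_xser)  # True: flag present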