Merged
24 changes: 15 additions & 9 deletions lm_engine/data/megatron/utils/__init__.py
@@ -16,6 +16,9 @@ class Split(Enum):
     test = 2
 
 
+_HELPERS = None
+
+
 def compile_helpers() -> None:
     """Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
 
@@ -24,37 +27,40 @@ def compile_helpers() -> None:
     build_directory = os.path.join(os.path.dirname(__file__), "build")
     os.makedirs(build_directory, exist_ok=True)
 
-    if ProcessGroupManager.get_global_rank() == 0:
-        load_cpp_extension(
+    def _compile():
+        global _HELPERS
+        _HELPERS = load_cpp_extension(
             "helpers",
             sources=os.path.join(os.path.dirname(__file__), "helpers.cpp"),
             extra_cflags=["-O3", "-Wall", "-shared", "-std=c++11", "-fPIC", "-fdiagnostics-color"],
             build_directory=build_directory,
             verbose=True,
         )
 
+    if ProcessGroupManager.get_global_rank() == 0:
+        _compile()
+
     Communication.barrier()
 
+    if ProcessGroupManager.get_global_rank() != 0:
+        _compile()
+
 
 def build_blending_indices(
     dataset_index: np.ndarray, dataset_sample_index: np.ndarray, weights: list[float], num_datasets: int, size: int
 ) -> None:
-    import helpers
-
-    helpers.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, size)
+    _HELPERS.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, size)
 
 
 def build_sample_idx(
     sizes: np.ndarray, doc_idx: np.ndarray, sequence_length: int, num_epochs: int, tokens_per_epoch: int
 ) -> np.ndarray:
-    import helpers
-
     if doc_idx.dtype == np.int32:
         log_rank_0(logging.INFO, f"using int32 for sample idx")
-        sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
+        sample_idx = _HELPERS.build_sample_idx_int32(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
     elif doc_idx.dtype == np.int64:
         log_rank_0(logging.INFO, f"using int64 for sample idx")
-        sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
+        sample_idx = _HELPERS.build_sample_idx_int64(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
     else:
         raise ValueError("unexpected dtype for doc_idx")
 
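Net effect of this file's change: rank 0 compiles the C++ extension, the barrier synchronizes all ranks, the remaining ranks then load the already-built artifact, and every rank keeps the resulting module in the `_HELPERS` cache, so `build_blending_indices` and `build_sample_idx` no longer `import helpers` at call time. A minimal usage sketch of the resulting flow (input values are illustrative, and a torch.distributed process group is assumed to be initialized already, since `compile_helpers` queries the global rank and calls a barrier):

    import numpy as np

    from lm_engine.data.megatron.utils import build_sample_idx, compile_helpers

    # Call once per process after the process group is up; this populates the
    # module-level _HELPERS cache on every rank.
    compile_helpers()

    # Illustrative inputs: 4 documents of 1000 tokens each, packed into
    # 2048-token sequences over 2 epochs (tokens_per_epoch = 4 * 1000).
    sizes = np.full(4, 1000, dtype=np.int32)
    doc_idx = np.tile(np.arange(4, dtype=np.int32), 2)

    # doc_idx is int32, so this dispatches to _HELPERS.build_sample_idx_int32.
    sample_idx = build_sample_idx(sizes, doc_idx, sequence_length=2048, num_epochs=2, tokens_per_epoch=4000)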
3 changes: 1 addition & 2 deletions lm_engine/train_utils.py
@@ -6,7 +6,6 @@
 
 import torch
 from torch.distributed import ReduceOp
-from transformers import AutoConfig
 
 from .enums import GradientCheckpointingMethod
 from .hf_models import CommonConfig, is_custom_model
@@ -86,7 +85,7 @@ def _get_attention_flops(batch_size: int, sequence_length: int, hidden_size: int
 
 
 def get_model_tflops(
-    config: AutoConfig | CommonConfig,
+    config: CommonConfig,
     batch_size: int,
     sequence_length: int,
     gradient_checkpointing_method: GradientCheckpointingMethod | None,
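On the lm_engine/train_utils.py side, `get_model_tflops` is now typed against the project's own `CommonConfig` only, and the now-unused `transformers.AutoConfig` import goes away with it. For orientation, `_get_attention_flops` in the second hunk supplies the attention term of that TFLOPs estimate; the usual back-of-the-envelope approximation looks like the sketch below (the textbook formula, not necessarily the exact expression this file uses):

    def attention_flops_estimate(batch_size: int, sequence_length: int, hidden_size: int) -> int:
        # Q, K, V and output projections: four [s, h] x [h, h] matmuls
        # -> roughly 8 * b * s * h^2 FLOPs per layer.
        projections = 8 * batch_size * sequence_length * hidden_size**2
        # Attention scores plus the weighted sum over values: two s-by-s-by-h
        # matmuls -> roughly 4 * b * s^2 * h FLOPs per layer.
        scores_and_context = 4 * batch_size * sequence_length**2 * hidden_size
        return projections + scores_and_context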