
Commit 166ee24

[C++] fix helpers import (#364)
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 7ff9d5b

2 files changed, +16 -11 lines changed

lm_engine/data/megatron/utils/__init__.py

Lines changed: 15 additions & 9 deletions
@@ -16,6 +16,9 @@ class Split(Enum):
     test = 2
 
 
+_HELPERS = None
+
+
 def compile_helpers() -> None:
     """Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
 
@@ -24,37 +27,40 @@ def compile_helpers() -> None:
     build_directory = os.path.join(os.path.dirname(__file__), "build")
     os.makedirs(build_directory, exist_ok=True)
 
-    if ProcessGroupManager.get_global_rank() == 0:
-        load_cpp_extension(
+    def _compile():
+        global _HELPERS
+        _HELPERS = load_cpp_extension(
             "helpers",
             sources=os.path.join(os.path.dirname(__file__), "helpers.cpp"),
             extra_cflags=["-O3", "-Wall", "-shared", "-std=c++11", "-fPIC", "-fdiagnostics-color"],
             build_directory=build_directory,
             verbose=True,
         )
 
+    if ProcessGroupManager.get_global_rank() == 0:
+        _compile()
+
     Communication.barrier()
 
+    if ProcessGroupManager.get_global_rank() != 0:
+        _compile()
+
 
 def build_blending_indices(
     dataset_index: np.ndarray, dataset_sample_index: np.ndarray, weights: list[float], num_datasets: int, size: int
 ) -> None:
-    import helpers
-
-    helpers.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, size)
+    _HELPERS.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, size)
 
 
 def build_sample_idx(
     sizes: np.ndarray, doc_idx: np.ndarray, sequence_length: int, num_epochs: int, tokens_per_epoch: int
 ) -> np.ndarray:
-    import helpers
-
     if doc_idx.dtype == np.int32:
         log_rank_0(logging.INFO, f"using int32 for sample idx")
-        sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
+        sample_idx = _HELPERS.build_sample_idx_int32(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
     elif doc_idx.dtype == np.int64:
         log_rank_0(logging.INFO, f"using int64 for sample idx")
-        sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
+        sample_idx = _HELPERS.build_sample_idx_int64(sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch)
     else:
         raise ValueError("unexpected dtype for doc_idx")
 
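The core of the fix is the synchronization pattern: rank 0 compiles the extension, every rank waits at the barrier, and the remaining ranks then load the already-built artifact from the shared build directory. Each rank caches the module object returned by the loader in the _HELPERS global instead of relying on "import helpers", which only resolves if the build directory happens to be on sys.path. Below is a minimal sketch of the same pattern, assuming an initialized torch.distributed process group and using torch.utils.cpp_extension.load directly in place of the repo's load_cpp_extension wrapper; the extension name and the example.cpp source file are illustrative, not part of this commit.

import os

import torch.distributed as dist
from torch.utils.cpp_extension import load

_EXTENSION = None  # cached module object, analogous to _HELPERS above


def compile_extension(build_directory: str) -> None:
    def _compile() -> None:
        global _EXTENSION
        # load() compiles the sources if needed and returns the module
        # object directly, so no sys.path lookup is required afterwards.
        _EXTENSION = load(
            name="example",
            sources=["example.cpp"],  # hypothetical source file
            build_directory=build_directory,
            verbose=True,
        )

    os.makedirs(build_directory, exist_ok=True)

    # Only rank 0 pays the compilation cost; all ranks then synchronize ...
    if dist.get_rank() == 0:
        _compile()
    dist.barrier()

    # ... and the remaining ranks load the prebuilt artifact afterwards.
    if dist.get_rank() != 0:
        _compile()

Without the barrier, the non-zero ranks could race rank 0's compiler and pick up a half-written shared object from the build directory; having every rank call the loader (rather than a bare import) is what lets them all hold a usable module handle.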

lm_engine/train_utils.py

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@
 
 import torch
 from torch.distributed import ReduceOp
-from transformers import AutoConfig
 
 from .enums import GradientCheckpointingMethod
 from .hf_models import CommonConfig, is_custom_model
@@ -86,7 +85,7 @@ def _get_attention_flops(batch_size: int, sequence_length: int, hidden_size: int
 
 
 def get_model_tflops(
-    config: AutoConfig | CommonConfig,
+    config: CommonConfig,
     batch_size: int,
     sequence_length: int,
     gradient_checkpointing_method: GradientCheckpointingMethod | None,
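This second change narrows get_model_tflops to the repo's own CommonConfig and drops the now-unused AutoConfig import. A hedged sketch of how a call site might guard for the narrower signature; maybe_get_model_tflops and its keyword passthrough are hypothetical, and the full get_model_tflops parameter list is truncated in the diff above:

from lm_engine.hf_models import CommonConfig
from lm_engine.train_utils import get_model_tflops


def maybe_get_model_tflops(config, **kwargs):
    # Plain transformers AutoConfig instances are no longer accepted,
    # so bail out for anything that is not the repo's own CommonConfig.
    if not isinstance(config, CommonConfig):
        return None
    return get_model_tflops(config=config, **kwargs)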
