@@ -16,6 +16,9 @@ class Split(Enum):
1616 test = 2
1717
1818
19+ _HELPERS = None
20+
21+
1922def compile_helpers () -> None :
2023 """Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
2124
@@ -24,37 +27,40 @@ def compile_helpers() -> None:
2427 build_directory = os .path .join (os .path .dirname (__file__ ), "build" )
2528 os .makedirs (build_directory , exist_ok = True )
2629
27- if ProcessGroupManager .get_global_rank () == 0 :
28- load_cpp_extension (
30+ def _compile ():
31+ global _HELPERS
32+ _HELPERS = load_cpp_extension (
2933 "helpers" ,
3034 sources = os .path .join (os .path .dirname (__file__ ), "helpers.cpp" ),
3135 extra_cflags = ["-O3" , "-Wall" , "-shared" , "-std=c++11" , "-fPIC" , "-fdiagnostics-color" ],
3236 build_directory = build_directory ,
3337 verbose = True ,
3438 )
3539
40+ if ProcessGroupManager .get_global_rank () == 0 :
41+ _compile ()
42+
3643 Communication .barrier ()
3744
45+ if ProcessGroupManager .get_global_rank () != 0 :
46+ _compile ()
47+
3848
def build_blending_indices(
    dataset_index: np.ndarray, dataset_sample_index: np.ndarray, weights: list[float], num_datasets: int, size: int
) -> None:
    """Build dataset blending indices by delegating to the compiled C++ helper.

    All arguments are forwarded unchanged, in the same order, to
    ``build_blending_indices`` on the extension module held in the
    module-level ``_HELPERS``.

    Args:
        dataset_index: Array forwarded to the helper. Presumably filled in
            place, since this function returns nothing; confirm against
            helpers.cpp.
        dataset_sample_index: Array forwarded to the helper. Presumably filled
            in place as well; confirm against helpers.cpp.
        weights: Per-dataset weights forwarded to the helper.
        num_datasets: Number of datasets, forwarded to the helper.
        size: Forwarded to the helper. Presumably the number of entries to
            generate; confirm against helpers.cpp.

    Raises:
        RuntimeError: If the helpers extension has not been loaded yet.
    """
    # _HELPERS stays None until compile_helpers() has loaded the extension; fail
    # with an actionable message rather than an opaque AttributeError on None.
    if _HELPERS is None:
        raise RuntimeError(
            "C++ helpers are not loaded; call compile_helpers() before build_blending_indices()"
        )

    _HELPERS.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, size)
4553
4654
4755def build_sample_idx (
4856 sizes : np .ndarray , doc_idx : np .ndarray , sequence_length : int , num_epochs : int , tokens_per_epoch : int
4957) -> np .ndarray :
50- import helpers
51-
5258 if doc_idx .dtype == np .int32 :
5359 log_rank_0 (logging .INFO , f"using int32 for sample idx" )
54- sample_idx = helpers .build_sample_idx_int32 (sizes , doc_idx , sequence_length , num_epochs , tokens_per_epoch )
60+ sample_idx = _HELPERS .build_sample_idx_int32 (sizes , doc_idx , sequence_length , num_epochs , tokens_per_epoch )
5561 elif doc_idx .dtype == np .int64 :
5662 log_rank_0 (logging .INFO , f"using int64 for sample idx" )
57- sample_idx = helpers .build_sample_idx_int64 (sizes , doc_idx , sequence_length , num_epochs , tokens_per_epoch )
63+ sample_idx = _HELPERS .build_sample_idx_int64 (sizes , doc_idx , sequence_length , num_epochs , tokens_per_epoch )
5864 else :
5965 raise ValueError ("unexpected dtype for doc_idx" )
6066
0 commit comments