Skip to content

Commit b4ea6d8

Browse files
authored
Some data-prep optimizations for pretraining (#45)
Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>
1 parent d50fc93 commit b4ea6d8

File tree

22 files changed

+3797
-172
lines changed

22 files changed

+3797
-172
lines changed

src/nemotron/data_prep/__init__.py

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,30 @@ class DataPrepConfig:
158158
sample: int | None = None
159159
"""Limit rows per dataset (for quick tests)"""
160160

161-
num_actors: int | None = None
162-
"""Ray actors for parallel processing (None = auto)"""
163-
164161
force: bool = False
165162
"""Force new run, ignoring cache"""
166163

167164
artifact_name: str | None = None
168165
"""Semantic artifact name (e.g., 'nano3/pretrain/data')"""
169166

167+
# Ray Data execution
168+
ray_data_enabled: bool = True
169+
"""Enable Ray Data executor for shard processing.
170+
Uses Ray Data's ActorPoolStrategy for automatic actor lifecycle
171+
management, resource accounting, and bottleneck metrics in W&B."""
172+
173+
ray_data_min_actors: int = 2
174+
"""Minimum actors for Ray Data executor (warm pool)"""
175+
176+
ray_data_max_actors: int | None = None
177+
"""Maximum actors for Ray Data executor (None = use all available CPUs)"""
178+
179+
ray_data_cpus_per_actor: float = 1.0
180+
"""CPUs per actor for Ray Data executor"""
181+
182+
ray_data_max_tasks_in_flight: int = 2
183+
"""Max tasks in flight per actor (pipelining depth)"""
184+
170185

171186
def run_data_prep(
172187
config: DataPrepConfig, *, artifact_class: type = PretrainBlendsArtifact
@@ -202,12 +217,6 @@ def run_data_prep(
202217
# Use object.__setattr__ since Dataset is a Pydantic model
203218
object.__setattr__(dataset, "text_field", config.text_field)
204219

205-
# Auto-detect num_actors from CPU count
206-
num_actors = config.num_actors
207-
if num_actors is None:
208-
cpu_count = os.cpu_count() or 4
209-
num_actors = max(2, min(32, cpu_count * 3 // 4))
210-
211220
# Build pipeline config
212221
# When sampling, use 1 shard to get exactly `sample` rows per dataset
213222
num_shards = config.num_shards
@@ -224,29 +233,7 @@ def run_data_prep(
224233
# Resolve output_dir to absolute path for W&B artifact storage
225234
output_dir = config.output_dir.resolve() if hasattr(config.output_dir, 'resolve') else Path(config.output_dir).resolve()
226235

227-
pipeline_config = PipelineConfig(
228-
output=OutputConfig(
229-
dir=output_dir,
230-
format=output_format,
231-
min_doc_chars=config.min_doc_chars,
232-
max_doc_tokens=config.max_doc_tokens,
233-
max_rows=config.sample,
234-
),
235-
tokenizer=TokenizerConfig(
236-
model=config.tokenizer_model,
237-
add_bos=config.add_bos,
238-
add_eos=config.add_eos,
239-
),
240-
num_actors=num_actors,
241-
force=config.force,
242-
split=config.split,
243-
per_split=config.per_split,
244-
)
245-
246-
# Initialize Ray with runtime_env excludes to prevent large directories from
247-
# being packaged. Without this, Ray auto-packages the working directory when
248-
# actors are created, which can exceed the 512MB GCS limit if output/ or other
249-
# large directories are present.
236+
# Initialize Ray early so we can query cluster resources
250237
import ray
251238

252239
if not ray.is_initialized():
@@ -268,6 +255,74 @@ def run_data_prep(
268255
}
269256
ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env)
270257

258+
# Build Ray Data config if enabled, auto-detecting cluster resources
259+
ray_data_config = None
260+
if config.ray_data_enabled:
261+
from nemotron.data_prep.config import RayDataConfig
262+
263+
# Auto-detect available CPUs from Ray cluster
264+
# Fallback chain: Ray cluster -> SLURM env var -> os.cpu_count()
265+
cluster_resources = ray.cluster_resources()
266+
ray_cpus = cluster_resources.get("CPU", 0)
267+
slurm_cpus = int(os.environ.get("SLURM_CPUS_PER_TASK", 0))
268+
os_cpus = os.cpu_count() or 4
269+
270+
# Use the highest available CPU count (Ray may report fewer due to config issues)
271+
available_cpus = max(int(ray_cpus), slurm_cpus, os_cpus)
272+
273+
# Use most of available CPUs for actors (leave some headroom)
274+
# min_actors = start with good parallelism
275+
# max_actors = allow scaling up to use all CPUs
276+
cpus_per_actor = config.ray_data_cpus_per_actor
277+
auto_max_actors = int(available_cpus * 0.9 / cpus_per_actor) # Use 90% of CPUs
278+
if config.ray_data_max_actors is not None:
279+
max_actors = min(config.ray_data_max_actors, auto_max_actors)
280+
else:
281+
max_actors = auto_max_actors
282+
min_actors = min(config.ray_data_min_actors, max_actors)
283+
284+
# Log resource detection for debugging
285+
print(f"Ray cluster resources: {cluster_resources}")
286+
print(f"CPU detection: Ray={ray_cpus}, SLURM={slurm_cpus}, os={os_cpus} -> using {available_cpus}")
287+
print(f"Ray Data config: min_actors={min_actors}, max_actors={max_actors}")
288+
289+
# Log W&B status for debugging
290+
try:
291+
import wandb
292+
if wandb.run is not None:
293+
print(f"[W&B] Active run: {wandb.run.name} (id={wandb.run.id})")
294+
else:
295+
print("[W&B] No active run - metrics will not be logged")
296+
except ImportError:
297+
print("[W&B] wandb not installed")
298+
299+
ray_data_config = RayDataConfig(
300+
enabled=True,
301+
min_actors=min_actors,
302+
max_actors=max_actors,
303+
cpus_per_actor=cpus_per_actor,
304+
max_tasks_in_flight_per_actor=config.ray_data_max_tasks_in_flight,
305+
)
306+
307+
pipeline_config = PipelineConfig(
308+
output=OutputConfig(
309+
dir=output_dir,
310+
format=output_format,
311+
min_doc_chars=config.min_doc_chars,
312+
max_doc_tokens=config.max_doc_tokens,
313+
max_rows=config.sample,
314+
),
315+
tokenizer=TokenizerConfig(
316+
model=config.tokenizer_model,
317+
add_bos=config.add_bos,
318+
add_eos=config.add_eos,
319+
),
320+
force=config.force,
321+
split=config.split,
322+
per_split=config.per_split,
323+
ray_data=ray_data_config,
324+
)
325+
271326
# Run processing pipeline
272327
result = last_mile_process(blend, pipeline_config)
273328

src/nemotron/data_prep/config.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,37 @@ def __post_init__(self) -> None:
193193
OutputFormat = BinIdxOutputConfig | JsonlOutputConfig | PackedOutputConfig | ChatSftOutputConfig
194194

195195

196+
@dataclass(frozen=True)
197+
class RayDataConfig:
198+
"""Configuration for Ray Data shard-task execution.
199+
200+
These settings map directly to Ray Data's ActorPoolStrategy and
201+
map_batches parameters, providing explicit control over resource usage.
202+
203+
When enabled, uses Ray Data's streaming executor for shard processing
204+
instead of manual actor pool management. Benefits include:
205+
- Automatic actor lifecycle management (no leaked actors)
206+
- Integrated backpressure with Ray's resource manager
207+
- Explicit CPU accounting per actor
208+
209+
Attributes:
210+
enabled: Enable Ray Data execution (vs legacy manual actors)
211+
min_actors: Minimum actors to keep alive (warm pool)
212+
max_actors: Maximum actors. None means use all available CPUs.
213+
cpus_per_actor: CPUs allocated per actor (explicit accounting)
214+
max_tasks_in_flight_per_actor: Pipelining depth to reduce scheduling
215+
bubbles and keep actors fed. Note: does not by itself parallelize
216+
a single actor; true I/O latency hiding requires either more actors
217+
(with fractional num_cpus) or async internal concurrency.
218+
"""
219+
220+
enabled: bool = False
221+
min_actors: int = 2
222+
max_actors: int | None = None # None = use all available CPUs
223+
cpus_per_actor: float = 1.0
224+
max_tasks_in_flight_per_actor: int = 2
225+
226+
196227
@dataclass(frozen=True)
197228
class OutputConfig:
198229
"""Output configuration.
@@ -250,22 +281,23 @@ class PipelineConfig:
250281
Attributes:
251282
output: Output settings
252283
tokenizer: Tokenizer settings (required for binidx/packed formats, optional for jsonl)
253-
num_actors: Number of Ray actors for parallel processing
254284
sample: Shard sampling spec ("10%", "5", or None for all)
255285
sample_seed: Random seed for sampling
256286
force: Force new run (ignore cached results)
257287
split: Split ratio for single-blend mode (e.g., "99990,8,2"). Deprecated.
258288
per_split: Per-split output configuration for Megatron-Bridge per_split_data_args_path
289+
ray_data: Ray Data execution configuration. When enabled and ray_data.enabled=True,
290+
uses Ray Data's ActorPoolStrategy for shard processing instead of manual actors.
259291
"""
260292

261293
output: OutputConfig
262294
tokenizer: TokenizerConfig | None = None
263-
num_actors: int = 4
264295
sample: str | int | None = None
265296
sample_seed: int = 42
266297
force: bool = False
267298
split: str | None = None # Deprecated - use per_split instead
268299
per_split: PerSplitConfig | None = None
300+
ray_data: RayDataConfig | None = None
269301

270302

271303
# ============================================================================

0 commit comments

Comments (0)