As large language models continue to grow in size, training and fine-tuning them becomes increasingly challenging: they can exceed the memory of a single device, and compute demands grow accordingly.
Pipeline parallelism addresses these challenges by splitting a model's layers across different devices and processing them in a pipelined fashion. Each device processes a different stage of the model, enabling training of models that wouldn't fit on a single device while maintaining high GPU utilization through overlapped computation.
AutoPipeline is NeMo AutoModel's high-level pipeline parallelism interface specifically designed for Hugging Face models, making pipeline parallelism as simple as data parallelism. Built on PyTorch's native `torch.distributed.pipelining`, AutoPipeline provides seamless pipeline parallelism support for any Hugging Face decoder-only causal language model with minimal code changes.
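As a rough sketch of the intended workflow (the import path and argument names below are assumptions about the API, not verified signatures; consult the API reference for the exact interface):

```python
# Illustrative sketch only -- the names here are assumptions, not the
# verified AutoPipeline interface.
from transformers import AutoModelForCausalLM

# Assumed import location, based on the package paths named in this guide.
from nemo_automodel.components.distributed.pipelining import AutoPipeline

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Hypothetical construction: split the model into pipeline stages and get
# back an object that slots into an existing training loop.
autopipeline = AutoPipeline(world_mesh=mesh, pp_size=2)  # `mesh`: your DeviceMesh
pipelined_model = autopipeline.build(model)
```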
For custom models and more granular control, the functional API in `nemo_automodel.components.distributed.pipelining.functional` provides modular, accessible building blocks that can be used with any PyTorch model architecture.
This guide walks you through the complete process of using AutoPipeline for Hugging Face models and the functional API for custom models. You'll learn how to configure pipeline stages, integrate with existing training workflows, optimize performance, and combine pipeline parallelism with other parallelization strategies.
:::{important}
Before proceeding with this guide, please ensure that you have NeMo AutoModel installed on your machine.
:::
### Model Patching (`patch_inner_model`, `patch_causal_lm_model`)
AutoPipeline splits a model by deep-copying it per stage and pruning away modules that don't belong to that stage. Many Hugging Face models assume the full module tree is present and return `ModelOutput` objects; after pruning, their original `forward()` often breaks (or returns objects that are awkward to pipeline).
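Conceptually, the split resembles the following model-agnostic sketch. This is illustrative only (not AutoPipeline's actual implementation) and assumes the usual Hugging Face layout of `model.model.layers`, `model.model.embed_tokens`, and `model.lm_head`:

```python
import copy

import torch.nn as nn

def prune_to_stage(model: nn.Module, keep: range, is_first: bool, is_last: bool) -> nn.Module:
    """Deep-copy the full model, then drop everything outside this stage."""
    stage = copy.deepcopy(model)
    decoder = stage.model  # inner decoder of a ...ForCausalLM wrapper
    for idx in range(len(decoder.layers)):
        if idx not in keep:
            # Setting a registered submodule to None frees its parameters while
            # keeping the attribute name -- this is why `layers` effectively
            # becomes sparse/dict-like and the patched forward must cope.
            decoder.layers[idx] = None
    if not is_first:
        decoder.embed_tokens = None  # only stage 0 embeds token IDs
    if not is_last:
        stage.lm_head = None  # only the last stage projects to logits
    return stage

# e.g., the second of four stages of a 32-layer model:
# stage = prune_to_stage(model, keep=range(8, 16), is_first=False, is_last=False)
```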
These two flags switch AutoPipeline to lightweight, pipeline-friendly `forward()` implementations that return tensors (see `nemo_automodel.components.distributed.pipelining.hf_utils.patch_hf_model_for_pp`):
- **`patch_inner_model`**: patches the *decoder module* (`model.model` for `...ForCausalLM`, otherwise the module itself) so each stage can run even after pruning.
  - **Stage 0** (has `embed_tokens`): takes token IDs and produces hidden states.
  - **Middle stages** (no `embed_tokens`): take hidden states from the previous stage (via `inputs_embeds`, or a float tensor passed through `input_ids`) and produce hidden states.
  - Handles sliced layer containers (e.g., `layers` becoming dict-like after stage pruning) and returns a **tensor** of hidden states so stages can be chained.
For compilation/performance, this patched forward prefers a precomputed `causal_mask_mapping` dict (it will fall back to computing masks and warn if you don't provide it).
- **`patch_causal_lm_model`**: patches the *`...ForCausalLM` wrapper* forward (the module that owns `lm_head`) so pipeline stages return tensors:
  - Returns **hidden states** when `lm_head` is absent on that stage.
  - Returns **logits** when `lm_head` is present (typically only the last stage).
  - Supports `logits_to_keep` to compute logits for only the last `k` tokens.
Note: this is only used when the module you pipeline is a `...ForCausalLM`-style wrapper (i.e., it has a `.model` attribute). If you pass a base decoder module directly, `patch_causal_lm_model` typically has no effect.
#### When Should I Change These?
- **Leave both `True` (default)** for standard Hugging Face `AutoModelForCausalLM` / `...ForCausalLM` models. This is the common case and gives the expected behavior: token IDs -> hidden states -> logits across stages.
- **Set both `False`** when your model already has a pipeline-friendly forward (returns tensors and can accept hidden states when embeddings are absent) or it needs custom kwargs/paths that the HF patch doesn't preserve (common for NeMo AutoModel-native model implementations, packed-sequence/`thd` paths, extra args like `padding_mask`, etc.). Many benchmark configs for NeMo-native models do this (for example `examples/benchmark/configs/qwen3_moe_30b_torch.yaml`).
- **Set `patch_inner_model=False, patch_causal_lm_model=True`** when your inner model is already stage-friendly, but the wrapper forward still returns a `ModelOutput` and you only want the wrapper simplified to “hidden states or logits”.
If you disable `patch_causal_lm_model`, your last stage will typically output hidden states instead of logits; in that case, make sure your `loss_fn` (or your last-stage module) applies the LM head explicitly.
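For instance, a minimal sketch of such a loss function (names illustrative; you are responsible for keeping a handle to the last stage's `lm_head`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_last_stage_loss_fn(lm_head: nn.Module):
    """Build a loss_fn for a last stage that emits hidden states, not logits."""

    def loss_fn(hidden_states: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = lm_head(hidden_states)  # [batch, seq, vocab]
        # Standard next-token shift for causal LM training.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,  # conventional ignore label for padding
        )

    return loss_fn
```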
### Automatic vs. Manual Layer Distribution
AutoPipeline offers flexible control over how your model is split across pipeline stages, supporting both automatic and manual layer distribution.
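Under the hood, such options resolve to stage boundaries in PyTorch's `torch.distributed.pipelining`, which AutoPipeline is built on. The toy example below shows a manual boundary at that level; it is independent of AutoPipeline's own configuration keys:

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import SplitPoint, pipeline

class ToyLM(nn.Module):
    def __init__(self, vocab: int = 128, dim: int = 64, n_layers: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        h = self.embed(ids)
        for layer in self.layers:
            h = torch.relu(layer(h))
        return self.head(h)

model = ToyLM()
example_ids = torch.randint(0, 128, (2, 16))  # one example microbatch

# A stage boundary placed right before `layers.4` yields two stages:
# (embed + layers 0-3) and (layers 4-7 + head).
pipe = pipeline(model, mb_args=(example_ids,), split_spec={"layers.4": SplitPoint.BEGINNING})
```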
## Using the Functional API for Custom Models
While AutoPipeline is specifically designed as a high-level interface for Hugging Face models, the functional API in `nemo_automodel.components.distributed.pipelining.functional` provides more modular and accessible building blocks that can be used with any PyTorch model, including custom architectures. This separation allows for cleaner code organization where AutoPipeline handles Hugging Face-specific optimizations while the functional module remains model-agnostic.
### Key Functional API Components
The functional API provides several utilities for building custom pipeline parallel systems:
#### 1. Stage ID Calculation
```python
from nemo_automodel.components.distributed.pipelining.functional import stage_ids_this_rank
```
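A hedged usage sketch follows, assuming a signature in the style of similarly named helpers in other pipeline-parallel codebases; verify against the module before relying on it:

```python
# Assumed signature -- inferred from the helper's name, not verified.
# With an interleaved ("loop") schedule, 4 pipeline ranks, and 8 total
# stages, rank 1 would own stages (1, 5): stage_id = pp_rank + i * pp_size.
stages = stage_ids_this_rank(pp_rank=1, pp_size=4, num_stages=8, style="loop")
```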
The functional API is designed to be more accessible and modular than AutoPipeline:
4. **Flexibility**: The functional API gives you complete control over how models are split and parallelized
5. **Testing**: Start with a small model and verify correct splitting before scaling up
The functional module's modular design makes it easier to integrate pipeline parallelism into existing custom model training workflows without the Hugging Face-specific assumptions that AutoPipeline makes.
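To ground this, a custom-model integration typically bottoms out in PyTorch's native stage and schedule objects. The condensed sketch below assumes `stage_module` already holds only this rank's slice of the model, and that `pp_rank`, `pp_size`, `device`, `input_ids`, and `labels` come from your training script:

```python
import torch.nn.functional as F
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Wrap this rank's model slice as one pipeline stage.
stage = PipelineStage(
    stage_module,
    stage_index=pp_rank,  # this rank's position in the pipeline
    num_stages=pp_size,
    device=device,
)

# GPipe-style schedule: split each batch into 4 microbatches.
# Replace F.cross_entropy with a loss shaped for your stage outputs.
schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=F.cross_entropy)

# First rank feeds inputs; last rank supplies targets and collects losses;
# middle ranks just relay activations.
if pp_rank == 0:
    schedule.step(input_ids)
elif pp_rank == pp_size - 1:
    losses = []
    schedule.step(target=labels, losses=losses)
else:
    schedule.step()
```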
## Mixed Parallelism
### Mixed Parallelism Examples
#### Pipeline + Data Parallelism (4 GPUs Total)
```bash
uv run torchrun --nproc_per_node=4 examples/llm/finetune.py \
  --config your_config.yaml \
  --dataloader.batch_size 16
```
#### Pipeline + Tensor Parallelism (4 GPUs Total)
```bash
uv run torchrun --nproc_per_node=4 examples/llm/finetune.py \
  --config your_config.yaml \
  --dataloader.batch_size 8
```
#### Full Hybrid: PP + DP + TP (8 GPUs Total)
```bash
uv run torchrun --nproc_per_node=8 examples/llm/finetune.py \
  --config your_config.yaml
```
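One way to picture the 8-GPU hybrid layout is as a 2x2x2 device mesh with one axis per parallelism style. A small sketch using PyTorch's mesh API (independent of the config keys above; assumes `torchrun` has already set up the 8-process group):

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs factored as pp=2, dp=2, tp=2; each style slices along its own axis.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

pp_mesh = mesh["pp"]  # pipeline stages: model layers split across these ranks
dp_mesh = mesh["dp"]  # data-parallel replicas/shards
tp_mesh = mesh["tp"]  # tensor-parallel groups within each stage
```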
AutoPipeline and the functional API together provide a complete pipeline parallelism solution for both Hugging Face and custom models. AutoPipeline offers a high-level, optimized interface specifically for Hugging Face models, while the functional module provides modular, accessible building blocks for custom architectures.
Key takeaways:
- Pipeline parallelism enables training of models too large for a single GPU
- AutoPipeline provides a simple API for Hugging Face models with powerful customization options
- The functional API offers modular components for implementing pipeline parallelism with any PyTorch model
- Both can be combined with other parallelization strategies for optimal performance
- Use built-in monitoring tools to understand and optimize your pipeline