# Helix Parallelism

Helix is a context parallelism (CP) technique for the decode/generation phase of LLM inference. Unlike traditional attention-FFN disaggregation (AFD) techniques, which spatially separate attention and FFN blocks onto different GPUs, Helix temporally separates them by reconfiguring the same GPUs.

For all details, see the original paper:
[Helix Parallelism: Rethinking Sharding Strategies for Interactive Multi-Million-Token LLM Decoding](https://arxiv.org/pdf/2507.07120)
## How Helix Works

In Helix parallelism:

- **KV cache distribution**: The KV cache is partitioned across CP ranks during generation, with each rank responsible for a portion of the cached context
- **Attention computation**: Each rank computes partial attention over its local KV cache shard
- **Attention postprocessing**: Partial results are combined / corrected across ranks to produce the final attention output
- **FFN layers**: CP ranks are repurposed as tensor parallelism (TP) ranks for FFN/MoE layers, maximizing GPU utilization
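The partial-attention and cross-rank correction steps above follow the standard log-sum-exp merging used when softmax attention is sharded along the sequence dimension. A minimal NumPy sketch of that merging (illustrative only, not the TensorRT-LLM implementation):

```python
import numpy as np

def full_attention(q, k, v):
    # Reference: softmax attention over the entire KV cache.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def partial_attention(q, k, v):
    # One CP rank: attention over its local KV shard, returning the
    # unnormalized output plus the stats needed to merge across ranks.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)   # local max, for numerical stability
    p = np.exp(scores - m)
    return p @ v, m[..., 0], p.sum(axis=-1)  # (unnormalized out, max, sum-exp)

def combine(partials):
    # Cross-rank correction: rescale each shard by exp(m_i - m_global),
    # then renormalize by the global softmax denominator.
    outs, ms, ls = zip(*partials)
    m_glob = np.maximum.reduce(ms)
    scales = [np.exp(m - m_glob) for m in ms]
    num = sum(s[:, None] * o for s, o in zip(scales, outs))
    den = sum(s * l for s, l in zip(scales, ls))
    return num / den[:, None]

# Splitting the KV cache across 4 "ranks" reproduces full attention exactly.
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(1, 8)), rng.normal(size=(16, 8)), rng.normal(size=(16, 8))
parts = [partial_attention(q, k[i:i + 4], v[i:i + 4]) for i in range(0, 16, 4)]
assert np.allclose(combine(parts), full_attention(q, k, v))
```

Because only the per-shard max and sum-exp need to be exchanged, the correction adds little communication relative to the KV cache traffic it avoids.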
## When to Use Helix

Helix parallelism provides performance benefits when **all** of the following conditions apply:

1. **Disaggregated serving**: Helix is designed for generation servers in a disaggregated (prefill/decode split) deployment architecture
2. **Long input sequences**: Performance gains typically appear at input sequence lengths of **64K tokens** or more
3. **Low batch sizes**: Optimal for latency-sensitive workloads with high tokens/second/user requirements

On a typical latency vs. throughput Pareto curve, Helix targets operating points toward the right side (low latency, high per-user throughput).
## Supported Models

Helix parallelism currently supports models using **Multi-head Latent Attention (MLA)** on the Blackwell GPU architecture:

- DeepSeek-V3 / DeepSeek-V3-Lite
## Configuration

### Configuration Parameters

Set the following parameters for the generation servers in disaggregated mode. An example configuration can be found in the end-to-end accuracy test described below.

| Parameter | Description | Required |
|-----------|-------------|----------|
| `context_parallel_size` | Number of GPUs for context parallelism (≥2 for Helix) | Yes |
| `cp_config.cp_type` | Must be `"HELIX"` or `CpType.HELIX` | Yes |
| `cp_config.tokens_per_block` | Tokens per KV cache block | Yes |
| `kv_cache_config.tokens_per_block` | Must match `cp_config.tokens_per_block` | Yes |
### JSON Configuration (for YAML/JSON configs)

```json
{
  "context_parallel_size": 2,
  "cp_config": {
    "cp_type": "HELIX",
    "tokens_per_block": 32
  },
  "kv_cache_config": {
    "tokens_per_block": 32
  }
}
```
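Since a mismatch between the two `tokens_per_block` values is an easy mistake to make, a small pre-flight check can catch it before servers are launched. A hypothetical helper (not part of TensorRT-LLM) that enforces the constraints from the parameter table, shown validating a config like the one above:

```python
import json

def validate_helix_config(cfg: dict) -> None:
    # Hypothetical pre-flight check mirroring the parameter table:
    # CP size >= 2, cp_type HELIX, and matching tokens_per_block values.
    if cfg.get("context_parallel_size", 1) < 2:
        raise ValueError("Helix requires context_parallel_size >= 2")
    cp = cfg.get("cp_config", {})
    if cp.get("cp_type") != "HELIX":
        raise ValueError('cp_config.cp_type must be "HELIX"')
    kv = cfg.get("kv_cache_config", {})
    if cp.get("tokens_per_block") != kv.get("tokens_per_block"):
        raise ValueError("cp_config.tokens_per_block must match "
                         "kv_cache_config.tokens_per_block")

cfg = json.loads("""
{
  "context_parallel_size": 2,
  "cp_config": {"cp_type": "HELIX", "tokens_per_block": 32},
  "kv_cache_config": {"tokens_per_block": 32}
}
""")
validate_helix_config(cfg)  # a consistent config passes silently
```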
## Testing Helix with TensorRT-LLM

### Unit Test: MLA Module Correctness

The simplest correctness test validates the [MLA attention module](../../../tensorrt_llm/_torch/modules/attention.py) with Helix enabled:

```bash
# Run the MLA Helix unit test
pytest tests/unittest/_torch/modules/test_mla_helix.py -v
```

This test verifies that attention outputs match between single-GPU and Helix-parallelized execution.
### End-to-End Accuracy Test

For end-to-end validation, the accuracy benchmark evaluates DeepSeek-V3-Lite in disaggregated mode on the MMLU and GSM8K benchmarks:

- Test location: `tests/integration/defs/accuracy/test_disaggregated_serving.py`
- Test name: `TestDeepSeekV3Lite::test_auto_dtype_with_helix`

This test demonstrates proper disaggregated server configuration with Helix.