Skip to content

Commit ed66d2a

Browse files
committed
rename FusedQKV to Fused
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 1dd4550 commit ed66d2a

File tree

2 files changed

+87
-35
lines changed

2 files changed

+87
-35
lines changed

nemo_automodel/components/distributed/optimized_tp_plans.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
including LLaMA, Qwen, Gemma3, and Ministral3 models.
1919
"""
2020

21-
from typing import Callable, Dict, Union, cast
21+
from typing import Callable, Dict, Optional, Union, cast
2222

2323
import torch
2424
from torch import nn
@@ -127,26 +127,30 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
127127
return RowwiseParallel._prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh)
128128

129129

130-
class FusedQKVColwiseParallel(ColwiseParallel):
130+
class FusedColwiseParallel(ColwiseParallel):
131131
"""Column-wise parallelism for fused Q/K/V (or Q/KV) linear projections.
132132
133-
A fused QKV linear has weight shape ``(num_sections * hidden, hidden)``
134-
whose output is the concatenation ``[Q | K | V]``. Standard
135-
``ColwiseParallel`` shards the first dimension contiguously, which
136-
crosses Q/K/V boundaries and mixes heads from different projections on
137-
different TP ranks.
133+
A fused QKV linear has a weight whose output is the concatenation
134+
``[Q | K | V]``. Standard ``ColwiseParallel`` shards the first
135+
dimension contiguously, which crosses Q/K/V boundaries and mixes
136+
heads from different projections on different TP ranks.
138137
139-
``FusedQKVColwiseParallel`` instead splits the weight (and bias) into
140-
``num_sections`` equal blocks and shards **each block** independently
141-
with ``Shard(0)``, then concatenates the local shards. This ensures
142-
every rank receives the correct head subset from every section.
138+
``FusedColwiseParallel`` instead splits the weight (and bias) into
139+
sections and shards **each section** independently with ``Shard(0)``,
140+
then concatenates the local shards. This ensures every rank receives
141+
the correct head subset from every section.
143142
144-
The forward output is ``(B, T, num_sections * hidden/tp)`` with
145-
per-section layout ``[Q_local | K_local | V_local]``. The standard
146-
reshape ``view(B, T, num_sections, -1, head_dim)`` followed by
147-
``unbind(dim=2)`` therefore produces correct local Q, K, V tensors.
143+
Sections can be either:
148144
149-
YAML string: ``"fused_qkv_colwise"``
145+
- **Equal-sized** (default): specified via ``num_sections``
146+
(e.g. ``num_sections=3`` for MHA where Q, K, V have the same size,
147+
or ``num_sections=2`` for fused gate_up projections).
148+
- **Variable-sized**: specified via ``section_sizes``
149+
(e.g. ``section_sizes=(q_size, kv_size, kv_size)`` for GQA where Q
150+
has more heads than K/V). When provided, takes precedence over
151+
``num_sections``.
152+
153+
YAML string: ``"fused_colwise"``
150154
151155
Note on checkpointing
152156
---------------------
@@ -159,32 +163,53 @@ class FusedQKVColwiseParallel(ColwiseParallel):
159163
QKV weight.
160164
161165
Args:
162-
num_sections: Number of fused sections. Default ``3`` for Q/K/V.
166+
num_sections: Number of equal-sized fused sections. Default ``3``
167+
for Q/K/V. Ignored when ``section_sizes`` is provided.
168+
section_sizes: Explicit per-section sizes along dim-0. Each
169+
section must be independently divisible by the TP world size.
170+
When provided, takes precedence over ``num_sections``.
163171
"""
164172

165-
def __init__(self, *, num_sections: int = 3, **kwargs):
173+
def __init__(
174+
self,
175+
*,
176+
num_sections: int = 3,
177+
section_sizes: Optional[tuple[int, ...]] = None,
178+
**kwargs,
179+
):
166180
super().__init__(**kwargs)
167181
self.num_sections = num_sections
182+
self.section_sizes = section_sizes
168183

169184
# -- custom partition function ------------------------------------
170185
def _partition_linear_fn(self, name, module, device_mesh):
171186
from torch.distributed.tensor import distribute_tensor
172187

173-
ns = self.num_sections
174188
for pname, param in list(module.named_parameters()):
175189
if isinstance(param, DTensor):
176190
continue # already distributed
177191

178192
dim0 = param.shape[0]
179-
if dim0 % ns != 0:
180-
raise ValueError(
181-
f"FusedQKVColwiseParallel: parameter '{pname}' dim-0 "
182-
f"({dim0}) is not divisible by num_sections ({ns})."
183-
)
184193

185-
# Split into sections, distribute each with Shard(0),
186-
# concatenate the local shards.
187-
sections = param.data.chunk(ns, dim=0)
194+
if self.section_sizes is not None:
195+
if sum(self.section_sizes) != dim0:
196+
raise ValueError(
197+
f"FusedColwiseParallel: parameter '{pname}' dim-0 "
198+
f"({dim0}) does not match sum of section_sizes "
199+
f"({self.section_sizes}, sum={sum(self.section_sizes)})."
200+
)
201+
sections = param.data.split(list(self.section_sizes), dim=0)
202+
else:
203+
ns = self.num_sections
204+
if dim0 % ns != 0:
205+
raise ValueError(
206+
f"FusedColwiseParallel: parameter '{pname}' dim-0 "
207+
f"({dim0}) is not divisible by num_sections ({ns})."
208+
)
209+
sections = param.data.chunk(ns, dim=0)
210+
211+
# Distribute each section with Shard(0), concatenate the
212+
# local shards.
188213
local_parts = []
189214
for sec in sections:
190215
dt = distribute_tensor(sec, device_mesh, [Shard(0)])
@@ -293,13 +318,22 @@ def _parallelize_llama(
293318
sequence_parallel: bool = False,
294319
) -> dict[str, ParallelStyle]:
295320
"""Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions."""
321+
# Compute per-section sizes for the fused QKV projection (GQA-aware).
322+
# For GQA, Q has more heads than K/V so the sections are unequal;
323+
# FusedColwiseParallel shards each section independently.
324+
head_dim = getattr(model.config, "head_dim", model.config.hidden_size // model.config.num_attention_heads)
325+
q_size = model.config.num_attention_heads * head_dim
326+
kv_size = model.config.num_key_value_heads * head_dim
327+
296328
base_model_tp_plan: dict[str, ParallelStyle] = {
297329
"model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
298330
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
299331
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
300332
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
301-
"model.layers.*.self_attn.qkv_proj": ColwiseParallel(), # Combined QKV projection
302-
"model.layers.*.mlp.gate_up_proj": ColwiseParallel(), # Fused gate and up projection
333+
"model.layers.*.self_attn.qkv_proj": FusedColwiseParallel(
334+
section_sizes=(q_size, kv_size, kv_size),
335+
),
336+
"model.layers.*.mlp.gate_up_proj": FusedColwiseParallel(num_sections=2),
303337
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
304338
"model.layers.*.mlp.up_proj": ColwiseParallel(),
305339
"model.layers.*.mlp.gate_proj": ColwiseParallel(),

nemo_automodel/components/distributed/parallelizer.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ def _is_transformers_v5_or_higher() -> bool:
7676
)
7777
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
7878

79-
from nemo_automodel.components.distributed.optimized_tp_plans import PARALLELIZE_FUNCTIONS, VocabParallelEmbedding
79+
from nemo_automodel.components.distributed.optimized_tp_plans import (
80+
PARALLELIZE_FUNCTIONS,
81+
FusedColwiseParallel,
82+
VocabParallelEmbedding,
83+
)
8084
from nemo_automodel.components.distributed.parallel_styles import translate_to_lora
8185

8286
# TODO(boxiangw): Change to MegatronFSDP once it got published
@@ -632,10 +636,10 @@ def translate_to_torch_parallel_style(style: str):
632636
return RowwiseParallel(input_layouts=Replicate())
633637
elif style == "sequence_parallel":
634638
return SequenceParallel()
635-
elif style == "fused_qkv_colwise":
636-
from nemo_automodel.components.distributed.optimized_tp_plans import FusedQKVColwiseParallel
639+
elif style == "fused_colwise":
640+
from nemo_automodel.components.distributed.optimized_tp_plans import FusedColwiseParallel # noqa: F811
637641

638-
return FusedQKVColwiseParallel()
642+
return FusedColwiseParallel()
639643
else:
640644
raise ValueError(f"Unknown parallel style: {style}")
641645

@@ -928,14 +932,28 @@ def _get_parallel_plan(
928932
model_parallel_plan = get_hf_tp_shard_plan(model)
929933

930934
else:
935+
# Compute per-section sizes for fused QKV projection (GQA-aware).
936+
# Standard ColwiseParallel does a contiguous Shard(0) split that
937+
# crosses Q/K/V boundaries; FusedColwiseParallel shards each
938+
# section independently so every TP rank gets the correct heads.
939+
try:
940+
head_dim = getattr(
941+
model.config, "head_dim", model.config.hidden_size // model.config.num_attention_heads
942+
)
943+
q_size = model.config.num_attention_heads * head_dim
944+
kv_size = model.config.num_key_value_heads * head_dim
945+
qkv_style = FusedColwiseParallel(section_sizes=(q_size, kv_size, kv_size))
946+
except (AttributeError, TypeError, ZeroDivisionError):
947+
qkv_style = ColwiseParallel()
948+
931949
base_model_tp_plan = {
932950
"model.embed_tokens": VocabParallelEmbedding(input_layouts=Replicate()),
933951
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
934952
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
935953
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
936-
"model.layers.*.self_attn.qkv_proj": ColwiseParallel(), # Combined QKV projection
954+
"model.layers.*.self_attn.qkv_proj": qkv_style,
937955
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
938-
"model.layers.*.mlp.gate_up_proj": ColwiseParallel(), # Fused gate and up projection
956+
"model.layers.*.mlp.gate_up_proj": FusedColwiseParallel(num_sections=2),
939957
"model.layers.*.mlp.up_proj": ColwiseParallel(),
940958
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
941959
"model.layers.*.mlp.down_proj": RowwiseParallel(),

0 commit comments

Comments (0)