1818including LLaMA, Qwen, Gemma3, and Ministral3 models.
1919"""
2020
21- from typing import Callable , Dict , Union , cast
21+ from typing import Callable , Dict , Optional , Union , cast
2222
2323import torch
2424from torch import nn
@@ -127,26 +127,30 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
127127 return RowwiseParallel ._prepare_output_fn (output_layouts , use_local_output , mod , outputs , device_mesh )
128128
129129
130- class FusedQKVColwiseParallel (ColwiseParallel ):
130+ class FusedColwiseParallel (ColwiseParallel ):
131131 """Column-wise parallelism for fused Q/K/V (or Q/KV) linear projections.
132132
133- A fused QKV linear has weight shape ``(num_sections * hidden, hidden)``
134- whose output is the concatenation ``[Q | K | V]``. Standard
135- ``ColwiseParallel`` shards the first dimension contiguously, which
136- crosses Q/K/V boundaries and mixes heads from different projections on
137- different TP ranks.
133+ A fused QKV linear has weight whose output is the concatenation
134+ ``[Q | K | V]``. Standard ``ColwiseParallel`` shards the first
135+ dimension contiguously, which crosses Q/K/V boundaries and mixes
136+ heads from different projections on different TP ranks.
138137
139- ``FusedQKVColwiseParallel `` instead splits the weight (and bias) into
140- ``num_sections`` equal blocks and shards **each block ** independently
141- with ``Shard(0)``, then concatenates the local shards. This ensures
142- every rank receives the correct head subset from every section.
138+ ``FusedColwiseParallel `` instead splits the weight (and bias) into
139+ sections and shards **each section ** independently with ``Shard(0)``,
140+ then concatenates the local shards. This ensures every rank receives
141+ the correct head subset from every section.
143142
144- The forward output is ``(B, T, num_sections * hidden/tp)`` with
145- per-section layout ``[Q_local | K_local | V_local]``. The standard
146- reshape ``view(B, T, num_sections, -1, head_dim)`` followed by
147- ``unbind(dim=2)`` therefore produces correct local Q, K, V tensors.
143+ Sections can be either:
148144
149- YAML string: ``"fused_qkv_colwise"``
145+ - **Equal-sized** (default): specified via ``num_sections``
146+ (e.g. ``num_sections=3`` for MHA where Q, K, V have the same size,
147+ or ``num_sections=2`` for fused gate_up projections).
148+ - **Variable-sized**: specified via ``section_sizes``
149+ (e.g. ``section_sizes=(q_size, kv_size, kv_size)`` for GQA where Q
150+ has more heads than K/V). When provided, takes precedence over
151+ ``num_sections``.
152+
153+ YAML string: ``"fused_colwise"``
150154
151155 Note on checkpointing
152156 ---------------------
@@ -159,32 +163,53 @@ class FusedQKVColwiseParallel(ColwiseParallel):
159163 QKV weight.
160164
161165 Args:
162- num_sections: Number of fused sections. Default ``3`` for Q/K/V.
166+ num_sections: Number of equal-sized fused sections. Default ``3``
167+ for Q/K/V. Ignored when ``section_sizes`` is provided.
168+ section_sizes: Explicit per-section sizes along dim-0. Each
169+ section must be independently divisible by the TP world size.
170+ When provided, takes precedence over ``num_sections``.
163171 """
164172
165- def __init__ (self , * , num_sections : int = 3 , ** kwargs ):
173+ def __init__ (
174+ self ,
175+ * ,
176+ num_sections : int = 3 ,
177+ section_sizes : Optional [tuple [int , ...]] = None ,
178+ ** kwargs ,
179+ ):
166180 super ().__init__ (** kwargs )
167181 self .num_sections = num_sections
182+ self .section_sizes = section_sizes
168183
169184 # -- custom partition function ------------------------------------
170185 def _partition_linear_fn (self , name , module , device_mesh ):
171186 from torch .distributed .tensor import distribute_tensor
172187
173- ns = self .num_sections
174188 for pname , param in list (module .named_parameters ()):
175189 if isinstance (param , DTensor ):
176190 continue # already distributed
177191
178192 dim0 = param .shape [0 ]
179- if dim0 % ns != 0 :
180- raise ValueError (
181- f"FusedQKVColwiseParallel: parameter '{ pname } ' dim-0 "
182- f"({ dim0 } ) is not divisible by num_sections ({ ns } )."
183- )
184193
185- # Split into sections, distribute each with Shard(0),
186- # concatenate the local shards.
187- sections = param .data .chunk (ns , dim = 0 )
194+ if self .section_sizes is not None :
195+ if sum (self .section_sizes ) != dim0 :
196+ raise ValueError (
197+ f"FusedColwiseParallel: parameter '{ pname } ' dim-0 "
198+ f"({ dim0 } ) does not match sum of section_sizes "
199+ f"({ self .section_sizes } , sum={ sum (self .section_sizes )} )."
200+ )
201+ sections = param .data .split (list (self .section_sizes ), dim = 0 )
202+ else :
203+ ns = self .num_sections
204+ if dim0 % ns != 0 :
205+ raise ValueError (
206+ f"FusedColwiseParallel: parameter '{ pname } ' dim-0 "
207+ f"({ dim0 } ) is not divisible by num_sections ({ ns } )."
208+ )
209+ sections = param .data .chunk (ns , dim = 0 )
210+
211+ # Distribute each section with Shard(0), concatenate the
212+ # local shards.
188213 local_parts = []
189214 for sec in sections :
190215 dt = distribute_tensor (sec , device_mesh , [Shard (0 )])
@@ -293,13 +318,22 @@ def _parallelize_llama(
293318 sequence_parallel : bool = False ,
294319) -> dict [str , ParallelStyle ]:
295320 """Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions."""
321+ # Compute per-section sizes for the fused QKV projection (GQA-aware).
322+ # For GQA, Q has more heads than K/V so the sections are unequal;
323+ # FusedColwiseParallel shards each section independently.
324+ head_dim = getattr (model .config , "head_dim" , model .config .hidden_size // model .config .num_attention_heads )
325+ q_size = model .config .num_attention_heads * head_dim
326+ kv_size = model .config .num_key_value_heads * head_dim
327+
296328 base_model_tp_plan : dict [str , ParallelStyle ] = {
297329 "model.embed_tokens" : RowwiseParallel (input_layouts = Replicate ()),
298330 "model.layers.*.self_attn.q_proj" : ColwiseParallel (),
299331 "model.layers.*.self_attn.k_proj" : ColwiseParallel (),
300332 "model.layers.*.self_attn.v_proj" : ColwiseParallel (),
301- "model.layers.*.self_attn.qkv_proj" : ColwiseParallel (), # Combined QKV projection
302- "model.layers.*.mlp.gate_up_proj" : ColwiseParallel (), # Fused gate and up projection
333+ "model.layers.*.self_attn.qkv_proj" : FusedColwiseParallel (
334+ section_sizes = (q_size , kv_size , kv_size ),
335+ ),
336+ "model.layers.*.mlp.gate_up_proj" : FusedColwiseParallel (num_sections = 2 ),
303337 "model.layers.*.self_attn.o_proj" : RowwiseParallel (),
304338 "model.layers.*.mlp.up_proj" : ColwiseParallel (),
305339 "model.layers.*.mlp.gate_proj" : ColwiseParallel (),