
Commit ae07f11

add unit tests
Signed-off-by: Chen Cui <chcui@nvidia.com>
1 parent de67930 commit ae07f11

5 files changed: +672 -14 lines changed


src/megatron/bridge/training/utils/packed_seq_utils.py

Lines changed: 6 additions & 4 deletions
@@ -44,15 +44,17 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
     cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")
 
     if cu_seqlens_argmin is not None:
-        cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
-        assert cu_seqlens_padded[cu_seqlens_argmin.item()] == -1  # cu_seqlens padding is -1
+        argmin_idx = cu_seqlens_argmin.item()
+        assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1  # cu_seqlens padding is -1
+        cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
     elif torch.min(cu_seqlens_padded) == -1:
         cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]
 
     if cu_seqlens_unpadded is not None:
         if cu_seqlens_unpadded_argmin is not None:
-            cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
-            assert cu_seqlens_unpadded[cu_seqlens_unpadded_argmin.item()] == -1  # cu_seqlens padding is -1
+            argmin_idx = cu_seqlens_unpadded_argmin.item()
+            assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1  # cu_seqlens padding is -1
+            cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
         elif torch.min(cu_seqlens_unpadded) == -1:
            cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
 
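Read as a standalone helper, the trimming logic this hunk lands on looks roughly like the sketch below (the helper name and standalone form are illustrative; in the repo the logic is inline in get_packed_seq_params). The point of the reorder is to check the padding marker before slicing, since indexing at argmin_idx after the slice would be out of range.

import torch

def trim_cu_seqlens(cu_seqlens, argmin=None):
    """Illustrative helper: drop trailing -1 padding from a 1-D cu_seqlens tensor."""
    if argmin is not None:
        idx = argmin.item()
        # Assert before slicing: once the tensor is cut to length `idx`,
        # indexing it at `idx` (as the pre-fix code did) would be out of bounds.
        assert idx == 0 or cu_seqlens[idx] == -1  # -1 marks padding
        return cu_seqlens[:idx]
    if torch.min(cu_seqlens) == -1:
        return cu_seqlens[: torch.argmin(cu_seqlens)]
    return cu_seqlens

print(trim_cu_seqlens(torch.tensor([0, 8, 16, -1, -1], dtype=torch.int32), torch.tensor(3)))  # [0, 8, 16]
print(trim_cu_seqlens(torch.tensor([0, 7, 14], dtype=torch.int32)))  # no -1 padding: returned unchanged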

src/megatron/bridge/training/vlm_step.py

Lines changed: 3 additions & 5 deletions
@@ -105,7 +105,7 @@ def pack_batch_sequences(
     position_ids: torch.Tensor,
     pad_token_id: int = 0,
     pad_to_multiple_of: int = 1,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Pack sequences in a batch by concatenating them and removing padding.
 
@@ -125,8 +125,7 @@ def pack_batch_sequences(
         - packed_attention_mask: None (not used with packing)
         - packed_position_ids: [1, total_len]
         - cu_seqlens: [num_sequences + 1] - cumulative sequence lengths
-        - cu_seqlens_argmin: 0 (dummy)
-        - max_seqlen: int - max sequence length in packed batch
+        - max_seqlen: tensor - max sequence length in packed batch
     """
     batch_size, seq_len = tokens.shape
     device = tokens.device
@@ -159,8 +158,7 @@ def pack_batch_sequences(
             attention_mask,
             position_ids[:1],
             torch.tensor([0, seq_len], dtype=torch.int32, device=device),
-            0,
-            seq_len,
+            torch.tensor(seq_len, dtype=torch.int32, device=device),
         )
 
     # Build cumulative sequence lengths
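A minimal illustration of the max_seqlen change in the early-return path above, with illustrative values: the dummy cu_seqlens_argmin of 0 and the plain-int seq_len are replaced by a single int32 tensor, consistent with how cu_seqlens is already returned.

import torch

seq_len = 128
device = "cpu"  # illustrative; in the function this is tokens.device
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
# Before: the tuple ended with `0, seq_len` (dummy argmin plus a Python int).
# After: it ends with one tensor, matching the updated return annotation.
max_seqlen = torch.tensor(seq_len, dtype=torch.int32, device=device)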

tests/unit_tests/models/ministral3/test_ministral3_provider.py

Lines changed: 173 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 
 from megatron.bridge.models.ministral3.ministral3_provider import (
     Ministral3ModelProvider,
@@ -172,3 +173,175 @@ def test_ministral3_14b_initialization(self):
         assert provider.ffn_hidden_size == 16384
         assert provider.num_layers == 40
         assert provider.rotary_base == 1000000000.0
+
+
+class TestGetLlama4AttnScale:
+    """Test cases for _get_llama_4_attn_scale function used in MinistralTEDotProductAttention.
+
+    This function computes attention scaling based on Llama 4 attention parameters.
+    The key change in PR 1997 is that it now handles different query shapes for
+    packed (3D) vs unpacked (4D) tensors.
+    """
+
+    def _get_llama_4_attn_scale(
+        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
+    ) -> torch.Tensor:
+        """Reimplementation of the function for testing."""
+        scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
+        num_dims_to_add = len(query_shape) - 1
+        for _ in range(num_dims_to_add):
+            scaling = scaling.unsqueeze(-1)
+        return scaling
+
+    def test_unpacked_4d_query_shape(self):
+        """Test attention scaling with unpacked 4D query shape [seq_len, batch, num_heads, head_dim]."""
+        seq_len = 8
+        batch_size = 2
+        num_heads = 4
+        head_dim = 64
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.1
+        max_position_embeddings = 16384
+        query_shape = (seq_len, batch_size, num_heads, head_dim)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Output should have shape [seq_len, 1, 1, 1] for broadcasting
+        assert scaling.shape == (seq_len, 1, 1, 1)
+
+        # First position should have scaling = 1 (since log(1 + 0) = 0)
+        expected_first = 1 + beta * torch.log(torch.tensor(1.0))
+        assert torch.isclose(scaling[0, 0, 0, 0], expected_first, atol=1e-6)
+
+    def test_packed_3d_query_shape(self):
+        """Test attention scaling with packed 3D query shape [seq_len, num_heads, head_dim]."""
+        seq_len = 16
+        num_heads = 8
+        head_dim = 32
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.2
+        max_position_embeddings = 8192
+        query_shape = (seq_len, num_heads, head_dim)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Output should have shape [seq_len, 1, 1] for broadcasting (3D - 1 = 2 dims added)
+        assert scaling.shape == (seq_len, 1, 1)
+
+        # Verify scaling values are computed correctly
+        expected = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
+        assert torch.allclose(scaling.squeeze(), expected, atol=1e-6)
+
+    def test_scaling_formula_correctness(self):
+        """Test that the scaling formula matches expected Llama 4 attention scaling."""
+        positions_ids = torch.tensor([0, 1, 100, 1000, 16384, 32768])
+        beta = 0.15
+        max_position_embeddings = 16384
+        query_shape = (6, 1, 1, 1)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Manual computation of expected values
+        # For position 0: 1 + 0.15 * log(1 + 0) = 1
+        # For position 16384: 1 + 0.15 * log(1 + 1) = 1 + 0.15 * log(2)
+        # For position 32768: 1 + 0.15 * log(1 + 2) = 1 + 0.15 * log(3)
+
+        expected_0 = 1.0
+        expected_16384 = 1 + beta * torch.log(torch.tensor(2.0))
+        expected_32768 = 1 + beta * torch.log(torch.tensor(3.0))
+
+        assert torch.isclose(scaling[0].squeeze(), torch.tensor(expected_0), atol=1e-6)
+        assert torch.isclose(scaling[4].squeeze(), expected_16384, atol=1e-6)
+        assert torch.isclose(scaling[5].squeeze(), expected_32768, atol=1e-6)
+
+    def test_beta_zero_returns_ones(self):
+        """Test that beta=0 returns all ones (no scaling)."""
+        positions_ids = torch.arange(10)
+        beta = 0.0
+        max_position_embeddings = 4096
+        query_shape = (10, 4, 64)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        assert torch.allclose(scaling.squeeze(), torch.ones(10), atol=1e-6)
+
+    def test_different_query_shapes_get_correct_dims(self):
+        """Test that different query shapes result in correct number of dimensions added."""
+        positions_ids = torch.arange(4)
+        beta = 0.1
+        max_position_embeddings = 1000
+
+        # 2D query shape
+        query_shape_2d = (4, 32)
+        scaling_2d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_2d)
+        assert scaling_2d.shape == (4, 1)  # 2-1 = 1 dim added
+
+        # 3D query shape (packed THD)
+        query_shape_3d = (4, 8, 32)
+        scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_3d)
+        assert scaling_3d.shape == (4, 1, 1)  # 3-1 = 2 dims added
+
+        # 4D query shape (unpacked BSHD)
+        query_shape_4d = (4, 2, 8, 32)
+        scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_4d)
+        assert scaling_4d.shape == (4, 1, 1, 1)  # 4-1 = 3 dims added
+
+    def test_broadcasting_compatibility(self):
+        """Test that scaling tensor is broadcastable to query tensor."""
+        seq_len = 8
+        num_heads = 4
+        head_dim = 64
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.1
+        max_position_embeddings = 16384
+
+        # Test for 3D packed format
+        query_3d = torch.randn(seq_len, num_heads, head_dim)
+        scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_3d.shape)
+
+        # Broadcasting should work
+        result_3d = query_3d * scaling_3d.to(query_3d.dtype)
+        assert result_3d.shape == query_3d.shape
+
+        # Test for 4D unpacked format
+        batch = 2
+        query_4d = torch.randn(seq_len, batch, num_heads, head_dim)
+        scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_4d.shape)
+
+        # Broadcasting should work
+        result_4d = query_4d * scaling_4d.to(query_4d.dtype)
+        assert result_4d.shape == query_4d.shape
+
+    def test_gpu_tensor_support(self):
+        """Test that the function works with GPU tensors if available."""
+        if not torch.cuda.is_available():
+            return  # Skip test if no GPU
+
+        positions_ids = torch.arange(8, device="cuda")
+        beta = 0.1
+        max_position_embeddings = 1024
+        query_shape = (8, 4, 32)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        assert scaling.device.type == "cuda"
+        assert scaling.shape == (8, 1, 1)
+
+    def test_dtype_preservation(self):
+        """Test that output dtype matches input positions_ids dtype."""
+        positions_ids_float32 = torch.arange(4, dtype=torch.float32)
+        positions_ids_float64 = torch.arange(4, dtype=torch.float64)
+        beta = 0.1
+        max_position_embeddings = 100
+        query_shape = (4, 2, 8)
+
+        scaling_32 = self._get_llama_4_attn_scale(positions_ids_float32, beta, max_position_embeddings, query_shape)
+        scaling_64 = self._get_llama_4_attn_scale(positions_ids_float64, beta, max_position_embeddings, query_shape)
+
+        # Note: torch.arange with int creates int tensors, but the function uses float operations
+        # The scaling result will be float due to log operation
+        assert scaling_32.dtype == torch.float32
+        assert scaling_64.dtype == torch.float64
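As a quick sanity check of the expected value in test_scaling_formula_correctness above, the same formula scaling = 1 + beta * log(1 + floor(pos / max_position_embeddings)), with natural log, gives:

import math

beta, max_pos = 0.15, 16384
pos = 32768
scaling = 1 + beta * math.log(1 + math.floor(pos / max_pos))  # 1 + 0.15 * ln(3)
print(round(scaling, 4))  # ~1.1648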

tests/unit_tests/training/test_gpt_step.py

Lines changed: 107 additions & 3 deletions
@@ -115,9 +115,8 @@ def test_packed_seq_params_no_padding(self):
         # Verify the result is a PackedSeqParams object
         assert isinstance(result, PackedSeqParams)
 
-        # When there's no -1 padding, argmin returns 0 (index of min value)
-        # So cu_seqlens[:0] returns empty tensor
-        expected_cu_seqlens = torch.empty(0, dtype=torch.int32)  # Empty tensor
+        # When there's no -1 padding, the tensor is returned unchanged
+        expected_cu_seqlens = torch.tensor([0, 7, 14], dtype=torch.int32)
         assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
         assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)
 
@@ -181,6 +180,111 @@ def test_packed_seq_params_all_fields_match(self):
         assert torch.equal(result.cu_seqlens_q, result.cu_seqlens_kv)
         assert torch.equal(result.max_seqlen_q, result.max_seqlen_kv)
 
+    def test_packed_seq_params_with_cu_seqlens_unpadded(self):
+        """Test functionality with cu_seqlens_unpadded for THD CP support."""
+        # Padded cu_seqlens (includes padding for CP divisibility)
+        cu_seqlens_padded = torch.tensor([[0, 8, 16, -1, -1]], dtype=torch.int32)
+        # Unpadded cu_seqlens (actual sequence boundaries)
+        cu_seqlens_unpadded = torch.tensor([[0, 6, 14, -1, -1]], dtype=torch.int32)
+
+        batch = {
+            "cu_seqlens": cu_seqlens_padded,
+            "cu_seqlens_unpadded": cu_seqlens_unpadded,
+            "max_seqlen": torch.tensor([[10]], dtype=torch.int32),
+        }
+
+        result = get_packed_seq_params(batch)
+
+        # cu_seqlens_q and cu_seqlens_kv should use unpadded values
+        expected_unpadded = torch.tensor([0, 6, 14], dtype=torch.int32)
+        assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+        assert torch.equal(result.cu_seqlens_kv, expected_unpadded)
+
+        # cu_seqlens_q_padded and cu_seqlens_kv_padded should use padded values
+        expected_padded = torch.tensor([0, 8, 16], dtype=torch.int32)
+        assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+        assert torch.equal(result.cu_seqlens_kv_padded, expected_padded)
+
+    def test_packed_seq_params_cu_seqlens_unpadded_with_argmin(self):
+        """Test cu_seqlens_unpadded processing with argmin hint."""
+        batch = {
+            "cu_seqlens": torch.tensor([[0, 4, 8, 12, -1, -1]], dtype=torch.int32),
+            "cu_seqlens_argmin": torch.tensor(4),  # Index where -1 starts
+            "cu_seqlens_unpadded": torch.tensor([[0, 3, 7, 10, -1, -1]], dtype=torch.int32),
+            "cu_seqlens_unpadded_argmin": torch.tensor(4),  # Index where -1 starts
+        }
+
+        result = get_packed_seq_params(batch)
+
+        # Verify unpadded values are used for q/kv
+        expected_unpadded = torch.tensor([0, 3, 7, 10], dtype=torch.int32)
+        assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+        assert torch.equal(result.cu_seqlens_kv, expected_unpadded)
+
+        # Verify padded values are set for _padded fields
+        expected_padded = torch.tensor([0, 4, 8, 12], dtype=torch.int32)
+        assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+        assert torch.equal(result.cu_seqlens_kv_padded, expected_padded)
+
+    def test_packed_seq_params_without_unpadded_fallback(self):
+        """Test fallback to cu_seqlens when cu_seqlens_unpadded is not provided."""
+        batch = {
+            "cu_seqlens": torch.tensor([[0, 5, 10, 15, -1]], dtype=torch.int32),
+            "max_seqlen": torch.tensor([[8]], dtype=torch.int32),
+        }
+
+        result = get_packed_seq_params(batch)
+
+        expected_cu_seqlens = torch.tensor([0, 5, 10, 15], dtype=torch.int32)
+
+        # Without unpadded, q/kv should use padded values
+        assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
+        assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)
+
+        # Padded fields should match q/kv
+        assert torch.equal(result.cu_seqlens_q_padded, expected_cu_seqlens)
+        assert torch.equal(result.cu_seqlens_kv_padded, expected_cu_seqlens)
+
+    def test_packed_seq_params_no_padding_in_cu_seqlens(self):
+        """Test when cu_seqlens has no -1 padding markers."""
+        batch = {
+            "cu_seqlens": torch.tensor([[0, 5, 10]], dtype=torch.int32),  # No -1 padding
+            "max_seqlen": torch.tensor([[7]], dtype=torch.int32),
+        }
+
+        result = get_packed_seq_params(batch)
+
+        # When no -1 present and min != -1, the tensor should remain as-is
+        expected = torch.tensor([0, 5, 10], dtype=torch.int32)
+        assert torch.equal(result.cu_seqlens_q, expected)
+        assert torch.equal(result.cu_seqlens_q_padded, expected)
+
+    def test_packed_seq_params_qkv_format_is_thd(self):
+        """Test that qkv_format is always set to 'thd'."""
+        batch = {
+            "cu_seqlens": torch.tensor([[0, 10, -1]], dtype=torch.int32),
+        }
+
+        result = get_packed_seq_params(batch)
+
+        assert result.qkv_format == "thd"
+
+    def test_packed_seq_params_cu_seqlens_unpadded_no_padding(self):
+        """Test cu_seqlens_unpadded with no padding markers."""
+        batch = {
+            "cu_seqlens": torch.tensor([[0, 6, 12]], dtype=torch.int32),
+            "cu_seqlens_unpadded": torch.tensor([[0, 5, 10]], dtype=torch.int32),  # No -1
+        }
+
+        result = get_packed_seq_params(batch)
+
+        # Unpadded should be used as-is since no -1 and min != -1
+        expected_unpadded = torch.tensor([0, 5, 10], dtype=torch.int32)
+        expected_padded = torch.tensor([0, 6, 12], dtype=torch.int32)
+
+        assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+        assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+
 
 class TestCreateLossFunction:
     """Tests for the _create_loss_function helper function."""
