Skip to content

Commit 8512cf2

Browse files
committed
add dit unit test.
Signed-off-by: Sajad Norouzi <snorouzi@nvidia.com>
1 parent 31d91c5 commit 8512cf2

11 files changed

+1624
-8
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import Mock, patch
16+
17+
from dfm.src.megatron.data.common.diffusion_energon_datamodule import (
18+
DiffusionDataModuleConfig,
19+
)
20+
21+
22+
def test_diffusion_data_module_config_initialization():
23+
"""Test DiffusionDataModuleConfig initialization and default values."""
24+
25+
# Mock the DiffusionDataModule to avoid actual dataset loading
26+
with patch("dfm.src.megatron.data.common.diffusion_energon_datamodule.DiffusionDataModule") as mock_data_module:
27+
# Setup the mock to return a mock dataset with seq_length attribute
28+
mock_dataset_instance = Mock()
29+
mock_dataset_instance.seq_length = 2048
30+
mock_data_module.return_value = mock_dataset_instance
31+
32+
# Create a DiffusionDataModuleConfig with required parameters
33+
config = DiffusionDataModuleConfig(
34+
path="/path/to/dataset",
35+
seq_length=2048,
36+
micro_batch_size=4,
37+
task_encoder_seq_length=512,
38+
packing_buffer_size=100,
39+
global_batch_size=32,
40+
num_workers=8,
41+
)
42+
43+
# Verify default values
44+
assert config.dataloader_type == "external", "Expected default dataloader_type to be 'external'"
45+
assert config.use_train_split_for_val is False, "Expected default use_train_split_for_val to be False"
46+
47+
# Verify required parameters are set correctly
48+
assert config.path == "/path/to/dataset"
49+
assert config.seq_length == 2048
50+
assert config.micro_batch_size == 4
51+
assert config.task_encoder_seq_length == 512
52+
assert config.packing_buffer_size == 100
53+
assert config.global_batch_size == 32
54+
assert config.num_workers == 8
55+
56+
# Verify that DiffusionDataModule was created in __post_init__
57+
assert mock_data_module.called, "DiffusionDataModule should be instantiated in __post_init__"
58+
59+
# Verify the dataset attribute was set
60+
assert config.dataset == mock_dataset_instance
61+
62+
# Verify sequence_length was set from the dataset
63+
assert config.sequence_length == 2048, "Expected sequence_length to be set from dataset.seq_length"
64+
65+
# Verify the DiffusionDataModule was created with correct parameters
66+
call_kwargs = mock_data_module.call_args.kwargs
67+
assert call_kwargs["path"] == "/path/to/dataset"
68+
assert call_kwargs["seq_length"] == 2048
69+
assert call_kwargs["micro_batch_size"] == 4
70+
assert call_kwargs["packing_buffer_size"] == 100
71+
assert call_kwargs["global_batch_size"] == 32
72+
assert call_kwargs["num_workers"] == 8
73+
assert call_kwargs["use_train_split_for_val"] is False
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample
18+
19+
20+
def test_add():
21+
"""Test __add__ method for DiffusionSample."""
22+
# Create two DiffusionSample instances with different seq_len_q
23+
sample1 = DiffusionSample(
24+
__key__="sample1",
25+
__restore_key__=(),
26+
__subflavor__=None,
27+
__subflavors__=["default"],
28+
video=torch.randn(3, 8, 16, 16),
29+
context_embeddings=torch.randn(10, 512),
30+
seq_len_q=torch.tensor(100),
31+
)
32+
sample2 = DiffusionSample(
33+
__key__="sample2",
34+
__restore_key__=(),
35+
__subflavor__=None,
36+
__subflavors__=["default"],
37+
video=torch.randn(3, 8, 16, 16),
38+
context_embeddings=torch.randn(10, 512),
39+
seq_len_q=torch.tensor(200),
40+
)
41+
42+
# Test adding two DiffusionSample instances
43+
result = sample1 + sample2
44+
assert result == 300, f"Expected 300, got {result}"
45+
46+
# Test adding DiffusionSample with an integer
47+
result = sample1 + 50
48+
assert result == 150, f"Expected 150, got {result}"
49+
50+
51+
def test_radd():
52+
"""Test __radd__ method for DiffusionSample."""
53+
# Create a DiffusionSample instance
54+
sample = DiffusionSample(
55+
__key__="sample",
56+
__restore_key__=(),
57+
__subflavor__=None,
58+
__subflavors__=["default"],
59+
video=torch.randn(3, 8, 16, 16),
60+
context_embeddings=torch.randn(10, 512),
61+
seq_len_q=torch.tensor(100),
62+
)
63+
64+
# Test reverse addition with an integer
65+
result = 50 + sample
66+
assert result == 150, f"Expected 150, got {result}"
67+
68+
# Test sum() function which uses __radd__ (starting with 0)
69+
samples = [
70+
DiffusionSample(
71+
__key__="sample1",
72+
__restore_key__=(),
73+
__subflavor__=None,
74+
__subflavors__=["default"],
75+
video=torch.randn(3, 8, 16, 16),
76+
context_embeddings=torch.randn(10, 512),
77+
seq_len_q=torch.tensor(10),
78+
),
79+
DiffusionSample(
80+
__key__="sample2",
81+
__restore_key__=(),
82+
__subflavor__=None,
83+
__subflavors__=["default"],
84+
video=torch.randn(3, 8, 16, 16),
85+
context_embeddings=torch.randn(10, 512),
86+
seq_len_q=torch.tensor(20),
87+
),
88+
DiffusionSample(
89+
__key__="sample3",
90+
__restore_key__=(),
91+
__subflavor__=None,
92+
__subflavors__=["default"],
93+
video=torch.randn(3, 8, 16, 16),
94+
context_embeddings=torch.randn(10, 512),
95+
seq_len_q=torch.tensor(30),
96+
),
97+
]
98+
result = sum(samples)
99+
assert result == 60, f"Expected 60, got {result}"
100+
101+
102+
def test_lt():
103+
"""Test __lt__ method for DiffusionSample."""
104+
# Create two DiffusionSample instances with different seq_len_q
105+
sample1 = DiffusionSample(
106+
__key__="sample1",
107+
__restore_key__=(),
108+
__subflavor__=None,
109+
__subflavors__=["default"],
110+
video=torch.randn(3, 8, 16, 16),
111+
context_embeddings=torch.randn(10, 512),
112+
seq_len_q=torch.tensor(100),
113+
)
114+
sample2 = DiffusionSample(
115+
__key__="sample2",
116+
__restore_key__=(),
117+
__subflavor__=None,
118+
__subflavors__=["default"],
119+
video=torch.randn(3, 8, 16, 16),
120+
context_embeddings=torch.randn(10, 512),
121+
seq_len_q=torch.tensor(200),
122+
)
123+
124+
# Test comparing two DiffusionSample instances
125+
assert sample1 < sample2, "Expected sample1 < sample2"
126+
assert not (sample2 < sample1), "Expected not (sample2 < sample1)"
127+
128+
# Test comparing DiffusionSample with an integer
129+
assert sample1 < 150, "Expected sample1 < 150"
130+
assert not (sample1 < 50), "Expected not (sample1 < 50)"
131+
132+
# Test sorting a list of DiffusionSample instances
133+
samples = [sample2, sample1]
134+
sorted_samples = sorted(samples)
135+
assert sorted_samples[0].seq_len_q.item() == 100, "Expected first element to have seq_len_q=100"
136+
assert sorted_samples[1].seq_len_q.item() == 200, "Expected second element to have seq_len_q=200"
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
17+
import torch
18+
19+
from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample
20+
from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking
21+
22+
23+
class ConcreteDiffusionTaskEncoder(DiffusionTaskEncoderWithSequencePacking):
24+
"""Concrete implementation for testing."""
25+
26+
def encode_sample(self, sample: dict) -> dict:
27+
"""Simple implementation for testing purposes."""
28+
return sample
29+
30+
def batch(self, samples: List[DiffusionSample]) -> dict:
31+
"""Simple batch implementation that returns first sample as dict."""
32+
if len(samples) == 1:
33+
sample = samples[0]
34+
return dict(
35+
video=sample.video.unsqueeze(0),
36+
context_embeddings=sample.context_embeddings.unsqueeze(0),
37+
context_mask=sample.context_mask.unsqueeze(0) if sample.context_mask is not None else None,
38+
loss_mask=sample.loss_mask.unsqueeze(0) if sample.loss_mask is not None else None,
39+
seq_len_q=sample.seq_len_q,
40+
seq_len_q_padded=sample.seq_len_q_padded,
41+
seq_len_kv=sample.seq_len_kv,
42+
seq_len_kv_padded=sample.seq_len_kv_padded,
43+
pos_ids=sample.pos_ids.unsqueeze(0) if sample.pos_ids is not None else None,
44+
latent_shape=sample.latent_shape,
45+
video_metadata=sample.video_metadata,
46+
)
47+
else:
48+
# For multiple samples, just return a simple dict
49+
return {"samples": samples}
50+
51+
52+
def create_diffusion_sample(key: str, seq_len: int, video_shape=(16, 8), embedding_dim=128) -> DiffusionSample:
53+
"""Helper function to create a DiffusionSample for testing."""
54+
return DiffusionSample(
55+
__key__=key,
56+
__restore_key__=(),
57+
__subflavor__=None,
58+
__subflavors__=["default"],
59+
video=torch.randn(seq_len, video_shape[0]),
60+
context_embeddings=torch.randn(10, embedding_dim),
61+
context_mask=torch.ones(10),
62+
loss_mask=torch.ones(seq_len),
63+
seq_len_q=torch.tensor([seq_len], dtype=torch.int32),
64+
seq_len_q_padded=torch.tensor([seq_len], dtype=torch.int32),
65+
seq_len_kv=torch.tensor([10], dtype=torch.int32),
66+
seq_len_kv_padded=torch.tensor([10], dtype=torch.int32),
67+
pos_ids=torch.arange(seq_len).unsqueeze(1),
68+
latent_shape=torch.tensor([4, 2, 4, 4], dtype=torch.int32),
69+
video_metadata={"fps": 30, "resolution": "512x512"},
70+
)
71+
72+
73+
def test_select_samples_to_pack():
74+
"""Test select_samples_to_pack method."""
75+
# Create encoder with seq_length=20
76+
encoder = ConcreteDiffusionTaskEncoder(seq_length=20)
77+
78+
# Create samples with different sequence lengths
79+
samples = [
80+
create_diffusion_sample("sample_1", seq_len=8),
81+
create_diffusion_sample("sample_2", seq_len=12),
82+
create_diffusion_sample("sample_3", seq_len=5),
83+
create_diffusion_sample("sample_4", seq_len=7),
84+
create_diffusion_sample("sample_5", seq_len=3),
85+
]
86+
87+
# Call select_samples_to_pack
88+
result = encoder.select_samples_to_pack(samples)
89+
90+
# Verify result is a list of lists
91+
assert isinstance(result, list), "Result should be a list"
92+
assert all(isinstance(group, list) for group in result), "All elements should be lists"
93+
94+
# Verify all samples are included
95+
all_samples = [sample for group in result for sample in group]
96+
assert len(all_samples) == len(samples), "All samples should be included"
97+
98+
# Verify no bin exceeds seq_length
99+
for group in result:
100+
total_seq_len = sum(sample.seq_len_q.item() for sample in group)
101+
assert total_seq_len <= encoder.seq_length, (
102+
f"Bin with total {total_seq_len} exceeds seq_length {encoder.seq_length}"
103+
)
104+
105+
# Verify that bins are non-empty
106+
assert all(len(group) > 0 for group in result), "No bin should be empty"
107+
108+
print(f"✓ Successfully packed {len(samples)} samples into {len(result)} bins")
109+
print(f" Bin sizes: {[sum(s.seq_len_q.item() for s in group) for group in result]}")
110+
111+
112+
def test_pack_selected_samples():
113+
"""Test pack_selected_samples method."""
114+
encoder = ConcreteDiffusionTaskEncoder(seq_length=100)
115+
116+
# Create multiple samples to pack
117+
sample_1_length = 10
118+
sample_2_length = 15
119+
sample_3_length = 8
120+
sample_1 = create_diffusion_sample("sample_1", seq_len=sample_1_length)
121+
sample_2 = create_diffusion_sample("sample_2", seq_len=sample_2_length)
122+
sample_3 = create_diffusion_sample("sample_3", seq_len=sample_3_length)
123+
124+
samples_to_pack = [sample_1, sample_2, sample_3]
125+
126+
# Pack the samples
127+
packed_sample = encoder.pack_selected_samples(samples_to_pack)
128+
129+
# Verify the packed sample is a DiffusionSample
130+
assert isinstance(packed_sample, DiffusionSample), "Result should be a DiffusionSample"
131+
132+
# Verify __key__ is concatenated
133+
expected_key = "sample_1,sample_2,sample_3"
134+
assert packed_sample.__key__ == expected_key, f"Key should be '{expected_key}'"
135+
136+
# Verify video is concatenated along dim 0
137+
expected_video_len = 10 + 15 + 8
138+
assert packed_sample.video.shape[0] == expected_video_len, f"Video should have length {expected_video_len}"
139+
140+
# Verify context_embeddings is concatenated
141+
expected_context_len = 10 * 3 # 3 samples with 10 embeddings each
142+
assert packed_sample.context_embeddings.shape[0] == expected_context_len, (
143+
f"Context embeddings should have length {expected_context_len}"
144+
)
145+
146+
# Verify context_mask is concatenated
147+
assert packed_sample.context_mask.shape[0] == expected_context_len, (
148+
f"Context mask should have length {expected_context_len}"
149+
)
150+
151+
# Verify loss_mask is concatenated
152+
assert packed_sample.loss_mask.shape[0] == expected_video_len, f"Loss mask should have length {expected_video_len}"
153+
154+
# Verify seq_len_q is concatenated
155+
assert packed_sample.seq_len_q.shape[0] == 3, "seq_len_q should have 3 elements"
156+
assert torch.equal(
157+
packed_sample.seq_len_q, torch.tensor([sample_1_length, sample_2_length, sample_3_length], dtype=torch.int32)
158+
), "seq_len_q values incorrect"
159+
160+
assert packed_sample.seq_len_q_padded.shape[0] == 3, "seq_len_q_padded should have 3 elements"
161+
assert torch.equal(
162+
packed_sample.seq_len_q_padded,
163+
torch.tensor([sample_1_length, sample_2_length, sample_3_length], dtype=torch.int32),
164+
), "seq_len_q_padded values incorrect"
165+
166+
assert packed_sample.seq_len_kv.shape[0] == 3, "seq_len_kv should have 3 elements"
167+
assert torch.equal(packed_sample.seq_len_kv, torch.tensor([10, 10, 10], dtype=torch.int32)), (
168+
"seq_len_kv values incorrect"
169+
)
170+
171+
assert packed_sample.seq_len_kv_padded.shape[0] == 3, "seq_len_kv_padded should have 3 elements"
172+
assert torch.equal(packed_sample.seq_len_kv_padded, torch.tensor([10, 10, 10], dtype=torch.int32)), (
173+
"seq_len_kv_padded values incorrect"
174+
)
175+
176+
assert packed_sample.latent_shape.shape[0] == 3, "latent_shape should have 3 rows"
177+
assert isinstance(packed_sample.video_metadata, list), "video_metadata should be a list"
178+
assert len(packed_sample.video_metadata) == 3, "video_metadata should have 3 elements"
179+
180+
print(f"✓ Successfully packed {len(samples_to_pack)} samples")
181+
print(f" Packed video shape: {packed_sample.video.shape}")
182+
print(f" Packed context embeddings shape: {packed_sample.context_embeddings.shape}")

0 commit comments

Comments
 (0)