Skip to content

Commit 7a1186d

Browse files
committed
add tests
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 5631753 commit 7a1186d

File tree

4 files changed

+167
-15
lines changed

4 files changed

+167
-15
lines changed

tests/unit_tests/_transformers/test_auto_model.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,19 @@
1414

1515
import logging
1616
import types
17-
from unittest.mock import MagicMock, Mock, patch
17+
from unittest.mock import MagicMock, patch
1818

1919
import pytest
2020
import torch
2121

2222
from nemo_automodel._transformers.auto_model import (
23+
_consume_config_overrides,
2324
_get_next_fallback_attn,
2425
_init_model,
2526
_patch_attention,
26-
_consume_config_overrides,
2727
)
2828
from nemo_automodel._transformers.infrastructure import _apply_peft_and_lower_precision
29-
from nemo_automodel._transformers.model_init import _filter_kwargs_for_init
30-
from nemo_automodel._transformers.model_init import _get_mixin_wrapped_class
29+
from nemo_automodel._transformers.model_init import _filter_kwargs_for_init, _get_mixin_wrapped_class
3130
from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin
3231

3332

@@ -514,6 +513,7 @@ class TestNeedSetupCacheClassesMapping:
514513
def test_shim_does_not_overwrite_existing_attribute(self):
515514
"""If NEED_SETUP_CACHE_CLASSES_MAPPING already exists, shim doesn't overwrite."""
516515
import importlib
516+
517517
import transformers.generation.utils as gen_utils
518518

519519
sentinel = {"test": "sentinel_value"}
@@ -532,6 +532,7 @@ def test_shim_does_not_overwrite_existing_attribute(self):
532532
def test_shim_creates_attribute_when_missing(self):
533533
"""If NEED_SETUP_CACHE_CLASSES_MAPPING is missing, shim creates it."""
534534
import importlib
535+
535536
import transformers.generation.utils as gen_utils
536537

537538
# Remove the attribute if it exists
@@ -699,3 +700,41 @@ def __init__(self):
699700

700701
assert is_custom is False
701702
mock_wrap.assert_called_once_with(FakeModel)
703+
704+
705+
class TestNeMoAutoModelForMultimodalLM:
706+
"""Tests for the NeMoAutoModelForMultimodalLM class and its exports."""
707+
708+
def test_class_exists_and_inherits_correctly(self):
709+
from transformers import AutoModelForMultimodalLM
710+
711+
from nemo_automodel._transformers.auto_model import NeMoAutoModelForMultimodalLM, _BaseNeMoAutoModelClass
712+
713+
assert issubclass(NeMoAutoModelForMultimodalLM, _BaseNeMoAutoModelClass)
714+
assert issubclass(NeMoAutoModelForMultimodalLM, AutoModelForMultimodalLM)
715+
716+
def test_has_from_pretrained_and_from_config(self):
717+
from nemo_automodel._transformers.auto_model import NeMoAutoModelForMultimodalLM
718+
719+
assert callable(NeMoAutoModelForMultimodalLM.from_pretrained)
720+
assert callable(NeMoAutoModelForMultimodalLM.from_config)
721+
722+
def test_lazy_export_from_transformers_subpackage(self):
723+
from nemo_automodel._transformers import NeMoAutoModelForMultimodalLM
724+
725+
assert NeMoAutoModelForMultimodalLM is not None
726+
727+
def test_lazy_export_from_top_level_package(self):
728+
from nemo_automodel import NeMoAutoModelForMultimodalLM
729+
730+
assert NeMoAutoModelForMultimodalLM is not None
731+
732+
def test_top_level_dir_includes_multimodal(self):
733+
import nemo_automodel
734+
735+
assert "NeMoAutoModelForMultimodalLM" in dir(nemo_automodel)
736+
737+
def test_transformers_subpackage_all_includes_multimodal(self):
738+
import nemo_automodel._transformers as pkg
739+
740+
assert "NeMoAutoModelForMultimodalLM" in pkg.__all__

tests/unit_tests/_transformers/test_transformers_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818

19-
from nemo_automodel._transformers.utils import sliding_window_overwrite
19+
from nemo_automodel._transformers.utils import apply_qwen3_omni_config_patch, sliding_window_overwrite
2020

2121

2222
class TestSlidingWindowOverwrite:
@@ -165,3 +165,43 @@ def test_sliding_window_overwrite_hasattr_behavior(self, mock_from_pretrained):
165165

166166
# Should return empty dict when use_sliding_window is not exactly False
167167
assert result == {}
168+
169+
170+
class TestApplyQwen3OmniConfigPatch:
171+
"""Test cases for apply_qwen3_omni_config_patch function."""
172+
173+
def test_patch_sets_use_sliding_window_default(self):
174+
"""Verify the patch adds use_sliding_window=False to the config class."""
175+
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
176+
Qwen3OmniMoeTalkerCodePredictorConfig,
177+
)
178+
179+
apply_qwen3_omni_config_patch()
180+
assert hasattr(Qwen3OmniMoeTalkerCodePredictorConfig, "use_sliding_window")
181+
182+
def test_patch_is_idempotent(self):
183+
"""Calling the patch twice does not raise or change the value."""
184+
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
185+
Qwen3OmniMoeTalkerCodePredictorConfig,
186+
)
187+
188+
apply_qwen3_omni_config_patch()
189+
apply_qwen3_omni_config_patch()
190+
assert Qwen3OmniMoeTalkerCodePredictorConfig.use_sliding_window is False
191+
192+
def test_patch_does_not_overwrite_existing_attribute(self):
193+
"""If the attribute already exists (e.g. fixed upstream), patch is a no-op."""
194+
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
195+
Qwen3OmniMoeTalkerCodePredictorConfig,
196+
)
197+
198+
original = getattr(Qwen3OmniMoeTalkerCodePredictorConfig, "use_sliding_window", None)
199+
Qwen3OmniMoeTalkerCodePredictorConfig.use_sliding_window = True
200+
try:
201+
apply_qwen3_omni_config_patch()
202+
assert Qwen3OmniMoeTalkerCodePredictorConfig.use_sliding_window is True
203+
finally:
204+
if original is None:
205+
del Qwen3OmniMoeTalkerCodePredictorConfig.use_sliding_window
206+
else:
207+
Qwen3OmniMoeTalkerCodePredictorConfig.use_sliding_window = original

tests/unit_tests/models/qwen3_omni_moe/test_qwen3_omni_moe_model.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525
Qwen3OmniMoeThinkerForConditionalGeneration as HFQwen3OmniMoeThinkerForConditionalGeneration,
2626
)
2727

28+
from nemo_automodel.components.models.common import BackendConfig
2829
from nemo_automodel.components.models.qwen3_omni_moe.model import (
2930
Qwen3OmniMoeThinkerForConditionalGeneration,
3031
Qwen3OmniMoeThinkerTextModel,
3132
)
3233
from nemo_automodel.components.moe.config import MoEConfig
33-
from nemo_automodel.components.models.common import BackendConfig
34-
3534

3635
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
3736

@@ -259,3 +258,47 @@ def test_modelclass_export_exists():
259258

260259
assert hasattr(omni_module, "ModelClass")
261260
assert omni_module.ModelClass is Qwen3OmniMoeThinkerForConditionalGeneration
261+
262+
263+
@patch.object(HFQwen3OmniMoeThinkerForConditionalGeneration, "__init__", new=_stub_hf_init)
264+
@patch("nemo_automodel.components.models.qwen3_omni_moe.model.Qwen3OmniMoeThinkerTextRotaryEmbedding")
265+
def test_forward_unpacks_vision_features_from_named_output(rotary_cls, thinker_config, backend_config, moe_config, device):
266+
"""get_image_features / get_video_features return a named output object;
267+
forward() must use .pooler_output and .deepstack_features, not tuple unpacking."""
268+
rotary_cls.return_value = MagicMock(side_effect=lambda x, y: (torch.zeros_like(x), torch.zeros_like(x)))
269+
model = Qwen3OmniMoeThinkerForConditionalGeneration(thinker_config, moe_config=moe_config, backend=backend_config).to(device)
270+
model.config = thinker_config
271+
272+
hidden_size = thinker_config.text_config.hidden_size
273+
vocab_size = thinker_config.text_config.vocab_size
274+
batch, seq_len = 1, 6
275+
276+
fake_image_embed = torch.randn(2, hidden_size, device=device)
277+
fake_deepstack = [torch.randn(2, hidden_size, device=device)]
278+
279+
fake_vision_output = SimpleNamespace(
280+
last_hidden_state=torch.randn(2, hidden_size, device=device),
281+
pooler_output=fake_image_embed,
282+
hidden_states=None,
283+
attentions=None,
284+
deepstack_features=fake_deepstack,
285+
)
286+
287+
hidden = torch.randn(batch, seq_len, hidden_size, device=device, dtype=model.lm_head.weight.dtype)
288+
input_ids = torch.randint(0, vocab_size, (batch, seq_len), device=device)
289+
290+
with (
291+
patch.object(model.model, "forward", return_value=hidden),
292+
patch.object(model, "get_image_features", return_value=fake_vision_output) as mock_gif,
293+
patch.object(model, "get_placeholder_mask", return_value=(
294+
torch.zeros(batch, seq_len, 1, dtype=torch.bool, device=device),
295+
torch.zeros(batch, seq_len, 1, dtype=torch.bool, device=device),
296+
torch.zeros(batch, seq_len, 1, dtype=torch.bool, device=device),
297+
)),
298+
):
299+
pixel_values = torch.randn(1, 3, 224, 224, device=device)
300+
image_grid_thw = torch.tensor([[1, 14, 14]], device=device)
301+
logits = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
302+
303+
mock_gif.assert_called_once()
304+
assert logits.shape == (batch, seq_len, vocab_size)

tests/unit_tests/recipes/test_finetune_vlm_helpers.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from contextlib import nullcontext
15+
from types import SimpleNamespace
16+
from unittest.mock import MagicMock, patch
17+
1418
import pytest
1519
import torch
1620
import torch.nn as nn
17-
from unittest.mock import patch, MagicMock
18-
from types import SimpleNamespace
19-
20-
from contextlib import nullcontext
2121

2222
from nemo_automodel.components.loggers.metric_logger import MetricsSample
2323
from nemo_automodel.recipes.vlm.finetune import (
@@ -41,7 +41,6 @@ def test_get_model_name_prefers_pretrained_path():
4141
assert _get_model_name(_Cfg()) is None
4242

4343

44-
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
4544

4645

4746

@@ -401,6 +400,7 @@ def test_autoprocessor_success():
401400
def test_autoprocessor_exception_handling(caplog):
402401
"""Test AutoProcessor exception handling and logging in build_dataloader."""
403402
import logging
403+
404404
from nemo_automodel.recipes.vlm.finetune import build_dataloader
405405

406406
with patch('transformers.AutoProcessor.from_pretrained') as mock_from_pretrained, \
@@ -434,6 +434,7 @@ def test_autoprocessor_exception_handling(caplog):
434434
def test_autoprocessor_with_processor_kwargs(caplog):
435435
"""Test AutoProcessor exception handling when cfg_processor has no instantiate method."""
436436
import logging
437+
437438
from nemo_automodel.recipes.vlm.finetune import build_dataloader
438439

439440
# Simple processor config class without instantiate method
@@ -707,13 +708,12 @@ def get(self, key, default=None):
707708

708709

709710
from nemo_automodel.recipes.vlm.finetune import (
710-
build_step_scheduler,
711-
build_lr_scheduler,
712711
build_checkpoint_config,
712+
build_lr_scheduler,
713+
build_step_scheduler,
713714
calculate_loss,
714715
)
715716

716-
717717
# -----------------------------------------------------------------------------
718718
# build_step_scheduler tests
719719
# -----------------------------------------------------------------------------
@@ -1847,3 +1847,33 @@ def get(self, key, default=None):
18471847
)
18481848

18491849
assert model is not None
1850+
1851+
1852+
@pytest.mark.parametrize("entry_point", ["from_config", "from_pretrained"])
1853+
def test_vlm_build_model_accepts_multimodal_lm_entry_points(entry_point):
1854+
"""Test that VLM build_model accepts NeMoAutoModelForMultimodalLM entry points."""
1855+
from nemo_automodel._transformers import NeMoAutoModelForMultimodalLM
1856+
1857+
target = getattr(NeMoAutoModelForMultimodalLM, entry_point)
1858+
1859+
class NeMoVLMModelConfig:
1860+
def __init__(self):
1861+
self._target_ = target
1862+
1863+
def instantiate(self, **kwargs):
1864+
return DummyModel()
1865+
1866+
def get(self, key, default=None):
1867+
return getattr(self, key, default)
1868+
1869+
cfg_model = NeMoVLMModelConfig()
1870+
1871+
with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
1872+
model = build_model(
1873+
cfg_model=cfg_model,
1874+
cfg_freeze=None,
1875+
cfg_peft=None,
1876+
seed=42,
1877+
)
1878+
1879+
assert model is not None

0 commit comments

Comments
 (0)