Skip to content

Commit b5af92d

Browse files
committed
Test cleanup and formatting issues addressed
Signed-off-by: Christopher Horvath <chorvath@nvidia.com>
1 parent 2834263 commit b5af92d

File tree

3 files changed

+121
-81
lines changed

3 files changed

+121
-81
lines changed

point_transformer_v3/tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
import sys
1414
from unittest.mock import MagicMock
1515

16+
1617
# Helper to mock a module and its submodules
1718
def mock_module(module_name):
1819
if module_name not in sys.modules:
1920
m = MagicMock()
20-
m.__path__ = [] # Make it look like a package
21+
m.__path__ = [] # Make it look like a package
2122
sys.modules[module_name] = m
2223
return sys.modules[module_name]
2324

25+
2426
# Mock pointops to avoid installing custom CUDA extensions for unit tests
2527
mock_module("pointops")
2628
mock_module("pointgroup_ops")
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
# Copyright Contributors to the OpenVDB Project
22
# SPDX-License-Identifier: Apache-2.0
3-

point_transformer_v3/tests/unit/test_order_type_validation.py

Lines changed: 118 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2,90 +2,129 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""
5-
Unit tests for order type validation in PTV3_Attention.
5+
Unit tests for space-filling curve functions in fvdb_utils.
66
7-
This is a minimal unit test to establish the test infrastructure.
8-
Tests a pure function with no dependencies on FVDB operations or external libraries.
7+
Tests the standalone curve type validation and dispatch logic
8+
in space_filling_curve_from_jagged_ijk and related functions.
99
"""
1010

1111
import unittest
12-
13-
from fvdb_extensions.models.ptv3_fvdb import PTV3_Attention
14-
15-
16-
class TestOrderTypeValidation(unittest.TestCase):
17-
"""Test cases for order type validation in PTV3_Attention."""
18-
19-
def setUp(self):
20-
"""Set up test fixtures."""
21-
pass
22-
23-
def test_permute_valid_order_types(self):
24-
"""Test that _permute accepts valid order types."""
25-
# Create a minimal attention instance (we won't call forward, just test _permute)
26-
attn = PTV3_Attention(
27-
hidden_size=64,
28-
num_heads=1,
29-
proj_drop=0.0,
30-
patch_size=0,
31-
order_index=0,
32-
order_types=("z", "z-trans"),
33-
)
34-
35-
# Test that valid order types don't raise errors
36-
# Note: We can't actually call _permute without a grid, but we can test the validation logic
37-
valid_types = ["z", "z-trans", "hilbert", "hilbert-trans"]
38-
for order_type in valid_types:
39-
# The _permute method should accept these without raising ValueError
40-
# We test this by checking the method exists and has the right signature
41-
self.assertTrue(hasattr(attn, "_permute"), "Should have _permute method")
42-
self.assertTrue(callable(getattr(attn, "_permute")), "_permute should be callable")
43-
44-
def test_permute_invalid_order_type_raises(self):
45-
"""Test that _permute raises ValueError for invalid order types."""
46-
attn = PTV3_Attention(
47-
hidden_size=64,
48-
num_heads=1,
49-
proj_drop=0.0,
50-
patch_size=0,
51-
order_index=0,
52-
order_types=("z",),
53-
)
54-
55-
# Create a mock grid object with minimal interface needed for _permute
56-
class MockGrid:
57-
def morton(self):
58-
return None
59-
60-
def morton_zyx(self):
61-
return None
62-
63-
def hilbert(self):
64-
return None
65-
66-
def hilbert_zyx(self):
67-
return None
68-
69-
mock_grid = MockGrid()
70-
71-
# Test that invalid order type raises ValueError
72-
with self.assertRaises(ValueError):
73-
attn._permute(mock_grid, "invalid_order_type")
74-
75-
def test_order_type_initialization(self):
76-
"""Test that order_types are correctly stored during initialization."""
77-
order_types = ("z", "z-trans", "hilbert")
78-
attn = PTV3_Attention(
79-
hidden_size=64,
80-
num_heads=1,
81-
proj_drop=0.0,
82-
patch_size=0,
83-
order_index=0,
84-
order_types=order_types,
85-
)
86-
self.assertEqual(attn.order_types, order_types, "Order types should be stored correctly")
12+
from unittest.mock import MagicMock, patch
13+
14+
from fvdb_extensions.models.fvdb_utils import (
15+
hilbert_flipped_from_jagged_ijk,
16+
hilbert_from_jagged_ijk,
17+
identity_from_jagged_ijk,
18+
morton_flipped_from_jagged_ijk,
19+
morton_from_jagged_ijk,
20+
space_filling_curve_from_jagged_ijk,
21+
)
22+
23+
24+
class TestSpaceFillingCurveValidation(unittest.TestCase):
25+
"""Test cases for curve type validation in space_filling_curve_from_jagged_ijk."""
26+
27+
def test_invalid_curve_type_raises_value_error(self):
28+
"""Test that invalid curve types raise ValueError."""
29+
mock_jagged_ijk = MagicMock()
30+
31+
invalid_types = ["invalid", "unknown", "random", "xyz", ""]
32+
for curve_type in invalid_types:
33+
with self.assertRaises(ValueError, msg=f"Should raise ValueError for '{curve_type}'"):
34+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, curve_type)
35+
36+
@patch("fvdb_extensions.models.fvdb_utils.morton_from_jagged_ijk")
37+
def test_morton_alias_z(self, mock_morton):
38+
"""Test that 'z' dispatches to morton_from_jagged_ijk."""
39+
mock_jagged_ijk = MagicMock()
40+
mock_morton.return_value = MagicMock()
41+
42+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "z")
43+
mock_morton.assert_called_once_with(mock_jagged_ijk)
44+
45+
@patch("fvdb_extensions.models.fvdb_utils.morton_from_jagged_ijk")
46+
def test_morton_alias_morton(self, mock_morton):
47+
"""Test that 'morton' dispatches to morton_from_jagged_ijk."""
48+
mock_jagged_ijk = MagicMock()
49+
mock_morton.return_value = MagicMock()
50+
51+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "morton")
52+
mock_morton.assert_called_once_with(mock_jagged_ijk)
53+
54+
@patch("fvdb_extensions.models.fvdb_utils.morton_flipped_from_jagged_ijk")
55+
def test_morton_flipped_alias_z_trans(self, mock_morton_flipped):
56+
"""Test that 'z-trans' dispatches to morton_flipped_from_jagged_ijk."""
57+
mock_jagged_ijk = MagicMock()
58+
mock_morton_flipped.return_value = MagicMock()
59+
60+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "z-trans")
61+
mock_morton_flipped.assert_called_once_with(mock_jagged_ijk)
62+
63+
@patch("fvdb_extensions.models.fvdb_utils.morton_flipped_from_jagged_ijk")
64+
def test_morton_flipped_alias_morton_zyx(self, mock_morton_flipped):
65+
"""Test that 'morton_zyx' dispatches to morton_flipped_from_jagged_ijk."""
66+
mock_jagged_ijk = MagicMock()
67+
mock_morton_flipped.return_value = MagicMock()
68+
69+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "morton_zyx")
70+
mock_morton_flipped.assert_called_once_with(mock_jagged_ijk)
71+
72+
@patch("fvdb_extensions.models.fvdb_utils.hilbert_from_jagged_ijk")
73+
def test_hilbert_dispatch(self, mock_hilbert):
74+
"""Test that 'hilbert' dispatches to hilbert_from_jagged_ijk."""
75+
mock_jagged_ijk = MagicMock()
76+
mock_hilbert.return_value = MagicMock()
77+
78+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "hilbert")
79+
mock_hilbert.assert_called_once_with(mock_jagged_ijk)
80+
81+
@patch("fvdb_extensions.models.fvdb_utils.hilbert_flipped_from_jagged_ijk")
82+
def test_hilbert_trans_dispatch(self, mock_hilbert_flipped):
83+
"""Test that 'hilbert-trans' dispatches to hilbert_flipped_from_jagged_ijk."""
84+
mock_jagged_ijk = MagicMock()
85+
mock_hilbert_flipped.return_value = MagicMock()
86+
87+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "hilbert-trans")
88+
mock_hilbert_flipped.assert_called_once_with(mock_jagged_ijk)
89+
90+
@patch("fvdb_extensions.models.fvdb_utils.identity_from_jagged_ijk")
91+
def test_identity_alias_vdb(self, mock_identity):
92+
"""Test that 'vdb' dispatches to identity_from_jagged_ijk."""
93+
mock_jagged_ijk = MagicMock()
94+
mock_identity.return_value = MagicMock()
95+
96+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "vdb")
97+
mock_identity.assert_called_once_with(mock_jagged_ijk)
98+
99+
@patch("fvdb_extensions.models.fvdb_utils.identity_from_jagged_ijk")
100+
def test_identity_alias_identity(self, mock_identity):
101+
"""Test that 'identity' dispatches to identity_from_jagged_ijk."""
102+
mock_jagged_ijk = MagicMock()
103+
mock_identity.return_value = MagicMock()
104+
105+
space_filling_curve_from_jagged_ijk(mock_jagged_ijk, "identity")
106+
mock_identity.assert_called_once_with(mock_jagged_ijk)
107+
108+
109+
class TestValidCurveTypes(unittest.TestCase):
110+
"""Test that all documented curve types are valid."""
111+
112+
def test_all_valid_curve_types_documented(self):
113+
"""Verify the complete set of valid curve types."""
114+
valid_types = {
115+
"z",
116+
"morton",
117+
"z-trans",
118+
"morton_zyx",
119+
"hilbert",
120+
"hilbert-trans",
121+
"vdb",
122+
"identity",
123+
}
124+
# This test documents the expected valid types
125+
# The dispatch tests above verify each one works
126+
self.assertEqual(len(valid_types), 8, "Should have 8 valid curve type aliases")
87127

88128

89129
if __name__ == "__main__":
90130
unittest.main()
91-

0 commit comments

Comments
 (0)