88
99from fvdb import GridBatch , JaggedTensor
1010from fvdb .utils .tests import dtype_to_atol
11- from fvdb .utils .tests .grid_utils import create_test_grid_batch
11+ from fvdb .utils .tests .grid_utils import make_grid_batch_and_jagged_point_data
1212
1313all_device_dtype_combos = [
1414 ["cuda" , torch .float16 ],
@@ -28,15 +28,9 @@ def setUp(self):
2828 def test_morton_permutation (self , device , dtype ):
2929 """Test Morton order permutations (xyz and zyx variants)."""
3030 # Create a test grid batch with known active voxels
31- batch_size = 2
32- grid_size = 8
33-
3431 # Create a grid batch with some active voxels
35- grid_batch = create_test_grid_batch (
36- batch_size = batch_size ,
37- grid_size = grid_size ,
38- device = device ,
39- density = 0.3 # 30% of voxels will be active
32+ grid_batch , _ , _ = make_grid_batch_and_jagged_point_data (
33+ device = device , dtype = dtype , include_boundary_points = True
4034 )
4135
4236 # Get permutation indices for both Morton orderings
@@ -51,17 +45,17 @@ def test_morton_permutation(self, device, dtype):
5145 continue
5246
5347 # Extract permutation indices for this grid
54- grid_perm = morton_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
55- grid_perm_zyx = morton_zyx_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
48+ grid_perm = morton_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
49+ grid_perm_zyx = morton_zyx_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
5650
5751 # Verify permutations contain all indices
5852 expected_indices = torch .arange (offset , offset + num_voxels , device = device )
5953 self .assertTrue (torch .sort (grid_perm )[0 ].equal (expected_indices ))
6054 self .assertTrue (torch .sort (grid_perm_zyx )[0 ].equal (expected_indices ))
6155
6256 # Get Morton codes
63- morton_codes = grid_batch .encode_morton ().jdata [offset : offset + num_voxels ]
64- morton_zyx_codes = grid_batch .encode_morton_zyx ().jdata [offset : offset + num_voxels ]
57+ morton_codes = grid_batch .encode_morton ().jdata [offset : offset + num_voxels ]
58+ morton_zyx_codes = grid_batch .encode_morton_zyx ().jdata [offset : offset + num_voxels ]
6559
6660 # Verify codes are sorted after applying permutation
6761 sorted_codes = morton_codes [grid_perm - offset ]
@@ -72,20 +66,13 @@ def test_morton_permutation(self, device, dtype):
7266
7367 offset += num_voxels
7468
75-
7669 @parameterized .expand (all_device_dtype_combos )
7770 def test_hilbert_permutation (self , device , dtype ):
7871 """Test Hilbert curve permutations (xyz and zyx variants)."""
7972 # Create a test grid batch with known active voxels
80- batch_size = 2
81- grid_size = 8
82-
8373 # Create a grid batch with some active voxels
84- grid_batch = create_test_grid_batch (
85- batch_size = batch_size ,
86- grid_size = grid_size ,
87- device = device ,
88- density = 0.3 # 30% of voxels will be active
74+ grid_batch , _ , _ = make_grid_batch_and_jagged_point_data (
75+ device = device , dtype = dtype , include_boundary_points = True
8976 )
9077
9178 # Get permutation indices for both Hilbert orderings
@@ -100,17 +87,17 @@ def test_hilbert_permutation(self, device, dtype):
10087 continue
10188
10289 # Extract permutation indices for this grid
103- grid_perm = hilbert_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
104- grid_perm_zyx = hilbert_zyx_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
90+ grid_perm = hilbert_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
91+ grid_perm_zyx = hilbert_zyx_perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
10592
10693 # Verify permutations contain all indices
10794 expected_indices = torch .arange (offset , offset + num_voxels , device = device )
10895 self .assertTrue (torch .sort (grid_perm )[0 ].equal (expected_indices ))
10996 self .assertTrue (torch .sort (grid_perm_zyx )[0 ].equal (expected_indices ))
11097
11198 # Get Hilbert codes
112- hilbert_codes = grid_batch .encode_hilbert ().jdata [offset : offset + num_voxels ]
113- hilbert_zyx_codes = grid_batch .encode_hilbert_zyx ().jdata [offset : offset + num_voxels ]
99+ hilbert_codes = grid_batch .encode_hilbert ().jdata [offset : offset + num_voxels ]
100+ hilbert_zyx_codes = grid_batch .encode_hilbert_zyx ().jdata [offset : offset + num_voxels ]
114101
115102 # Verify codes are sorted after applying permutation
116103 sorted_codes = hilbert_codes [grid_perm - offset ]
@@ -121,29 +108,21 @@ def test_hilbert_permutation(self, device, dtype):
121108
122109 offset += num_voxels
123110
124-
125111 @parameterized .expand (all_device_dtype_combos )
126112 def test_permutation_validity (self , device , dtype ):
127113 """Test that permutation indices are valid (complete and unique)."""
128- batch_size = 3
129- grid_size = 16
130-
131- # Create grid batch with varying densities
132- densities = [0.1 , 0.5 , 0.8 ] # Test different sparsity levels
133- for density in densities :
134- grid_batch = create_test_grid_batch (
135- batch_size = batch_size ,
136- grid_size = grid_size ,
137- device = device ,
138- density = density
114+ # Create multiple grid batches to test different configurations
115+ for _ in range (3 ):
116+ grid_batch , _ , _ = make_grid_batch_and_jagged_point_data (
117+ device = device , dtype = dtype , include_boundary_points = True
139118 )
140119
141120 # Test all permutation types
142121 permutations = [
143122 grid_batch .permutation_morton (),
144123 grid_batch .permutation_morton_zyx (),
145124 grid_batch .permutation_hilbert (),
146- grid_batch .permutation_hilbert_zyx ()
125+ grid_batch .permutation_hilbert_zyx (),
147126 ]
148127
149128 for perm in permutations :
@@ -154,7 +133,7 @@ def test_permutation_validity(self, device, dtype):
154133 continue
155134
156135 # Extract permutation indices for this grid
157- grid_perm = perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
136+ grid_perm = perm .jdata [offset : offset + num_voxels ].squeeze (- 1 )
158137
159138 # Verify indices are within valid range
160139 self .assertTrue (torch .all (grid_perm >= offset ))
@@ -169,5 +148,6 @@ def test_permutation_validity(self, device, dtype):
169148
170149 offset += num_voxels
171150
151+
172152if __name__ == "__main__" :
173153 unittest .main ()
0 commit comments