Skip to content

Commit 768a060

Browse files
author
Hexu Zhao
committed
format
Signed-off-by: Hexu Zhao <hexuz@nvidia.com>
1 parent bd7441d commit 768a060

File tree

1 file changed

+20
-40
lines changed

1 file changed

+20
-40
lines changed

tests/unit/test_serialization.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fvdb import GridBatch, JaggedTensor
1010
from fvdb.utils.tests import dtype_to_atol
11-
from fvdb.utils.tests.grid_utils import create_test_grid_batch
11+
from fvdb.utils.tests.grid_utils import make_grid_batch_and_jagged_point_data
1212

1313
all_device_dtype_combos = [
1414
["cuda", torch.float16],
@@ -28,15 +28,9 @@ def setUp(self):
2828
def test_morton_permutation(self, device, dtype):
2929
"""Test Morton order permutations (xyz and zyx variants)."""
3030
# Create a test grid batch with known active voxels
31-
batch_size = 2
32-
grid_size = 8
33-
3431
# Create a grid batch with some active voxels
35-
grid_batch = create_test_grid_batch(
36-
batch_size=batch_size,
37-
grid_size=grid_size,
38-
device=device,
39-
density=0.3 # 30% of voxels will be active
32+
grid_batch, _, _ = make_grid_batch_and_jagged_point_data(
33+
device=device, dtype=dtype, include_boundary_points=True
4034
)
4135

4236
# Get permutation indices for both Morton orderings
@@ -51,17 +45,17 @@ def test_morton_permutation(self, device, dtype):
5145
continue
5246

5347
# Extract permutation indices for this grid
54-
grid_perm = morton_perm.jdata[offset:offset + num_voxels].squeeze(-1)
55-
grid_perm_zyx = morton_zyx_perm.jdata[offset:offset + num_voxels].squeeze(-1)
48+
grid_perm = morton_perm.jdata[offset : offset + num_voxels].squeeze(-1)
49+
grid_perm_zyx = morton_zyx_perm.jdata[offset : offset + num_voxels].squeeze(-1)
5650

5751
# Verify permutations contain all indices
5852
expected_indices = torch.arange(offset, offset + num_voxels, device=device)
5953
self.assertTrue(torch.sort(grid_perm)[0].equal(expected_indices))
6054
self.assertTrue(torch.sort(grid_perm_zyx)[0].equal(expected_indices))
6155

6256
# Get Morton codes
63-
morton_codes = grid_batch.encode_morton().jdata[offset:offset + num_voxels]
64-
morton_zyx_codes = grid_batch.encode_morton_zyx().jdata[offset:offset + num_voxels]
57+
morton_codes = grid_batch.encode_morton().jdata[offset : offset + num_voxels]
58+
morton_zyx_codes = grid_batch.encode_morton_zyx().jdata[offset : offset + num_voxels]
6559

6660
# Verify codes are sorted after applying permutation
6761
sorted_codes = morton_codes[grid_perm - offset]
@@ -72,20 +66,13 @@ def test_morton_permutation(self, device, dtype):
7266

7367
offset += num_voxels
7468

75-
7669
@parameterized.expand(all_device_dtype_combos)
7770
def test_hilbert_permutation(self, device, dtype):
7871
"""Test Hilbert curve permutations (xyz and zyx variants)."""
7972
# Create a test grid batch with known active voxels
80-
batch_size = 2
81-
grid_size = 8
82-
8373
# Create a grid batch with some active voxels
84-
grid_batch = create_test_grid_batch(
85-
batch_size=batch_size,
86-
grid_size=grid_size,
87-
device=device,
88-
density=0.3 # 30% of voxels will be active
74+
grid_batch, _, _ = make_grid_batch_and_jagged_point_data(
75+
device=device, dtype=dtype, include_boundary_points=True
8976
)
9077

9178
# Get permutation indices for both Hilbert orderings
@@ -100,17 +87,17 @@ def test_hilbert_permutation(self, device, dtype):
10087
continue
10188

10289
# Extract permutation indices for this grid
103-
grid_perm = hilbert_perm.jdata[offset:offset + num_voxels].squeeze(-1)
104-
grid_perm_zyx = hilbert_zyx_perm.jdata[offset:offset + num_voxels].squeeze(-1)
90+
grid_perm = hilbert_perm.jdata[offset : offset + num_voxels].squeeze(-1)
91+
grid_perm_zyx = hilbert_zyx_perm.jdata[offset : offset + num_voxels].squeeze(-1)
10592

10693
# Verify permutations contain all indices
10794
expected_indices = torch.arange(offset, offset + num_voxels, device=device)
10895
self.assertTrue(torch.sort(grid_perm)[0].equal(expected_indices))
10996
self.assertTrue(torch.sort(grid_perm_zyx)[0].equal(expected_indices))
11097

11198
# Get Hilbert codes
112-
hilbert_codes = grid_batch.encode_hilbert().jdata[offset:offset + num_voxels]
113-
hilbert_zyx_codes = grid_batch.encode_hilbert_zyx().jdata[offset:offset + num_voxels]
99+
hilbert_codes = grid_batch.encode_hilbert().jdata[offset : offset + num_voxels]
100+
hilbert_zyx_codes = grid_batch.encode_hilbert_zyx().jdata[offset : offset + num_voxels]
114101

115102
# Verify codes are sorted after applying permutation
116103
sorted_codes = hilbert_codes[grid_perm - offset]
@@ -121,29 +108,21 @@ def test_hilbert_permutation(self, device, dtype):
121108

122109
offset += num_voxels
123110

124-
125111
@parameterized.expand(all_device_dtype_combos)
126112
def test_permutation_validity(self, device, dtype):
127113
"""Test that permutation indices are valid (complete and unique)."""
128-
batch_size = 3
129-
grid_size = 16
130-
131-
# Create grid batch with varying densities
132-
densities = [0.1, 0.5, 0.8] # Test different sparsity levels
133-
for density in densities:
134-
grid_batch = create_test_grid_batch(
135-
batch_size=batch_size,
136-
grid_size=grid_size,
137-
device=device,
138-
density=density
114+
# Create multiple grid batches to test different configurations
115+
for _ in range(3):
116+
grid_batch, _, _ = make_grid_batch_and_jagged_point_data(
117+
device=device, dtype=dtype, include_boundary_points=True
139118
)
140119

141120
# Test all permutation types
142121
permutations = [
143122
grid_batch.permutation_morton(),
144123
grid_batch.permutation_morton_zyx(),
145124
grid_batch.permutation_hilbert(),
146-
grid_batch.permutation_hilbert_zyx()
125+
grid_batch.permutation_hilbert_zyx(),
147126
]
148127

149128
for perm in permutations:
@@ -154,7 +133,7 @@ def test_permutation_validity(self, device, dtype):
154133
continue
155134

156135
# Extract permutation indices for this grid
157-
grid_perm = perm.jdata[offset:offset + num_voxels].squeeze(-1)
136+
grid_perm = perm.jdata[offset : offset + num_voxels].squeeze(-1)
158137

159138
# Verify indices are within valid range
160139
self.assertTrue(torch.all(grid_perm >= offset))
@@ -169,5 +148,6 @@ def test_permutation_validity(self, device, dtype):
169148

170149
offset += num_voxels
171150

151+
172152
if __name__ == "__main__":
173153
unittest.main()

0 commit comments

Comments
 (0)