Skip to content

Commit e1542cf

Browse files
author
Hexu Zhao
committed
modify test_serialization.py
Signed-off-by: Hexu Zhao <hexuz@nvidia.com>
1 parent 23a9205 commit e1542cf

File tree

1 file changed

+6
-120
lines changed

1 file changed

+6
-120
lines changed

tests/unit/test_serialization.py

Lines changed: 6 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from parameterized import parameterized
1010

1111
from fvdb import GridBatch, JaggedTensor
12-
from fvdb.grid_batch import save_gridbatch, load_gridbatch
13-
from fvdb.utils.tests import dtype_to_atol
1412
from fvdb.utils.tests.grid_utils import make_grid_batch_and_jagged_point_data
1513

1614
all_device_dtype_combos = [
@@ -40,8 +38,8 @@ def test_morton_codes(self, device, dtype):
4038
morton_zyx_codes = grid_batch.morton_zyx()
4139

4240
# Test that codes are returned as JaggedTensor
43-
self.assertIsInstance(morton_codes, JaggedTensor)
44-
self.assertIsInstance(morton_zyx_codes, JaggedTensor)
41+
# self.assertIsInstance(morton_codes, JaggedTensor)
42+
# self.assertIsInstance(morton_zyx_codes, JaggedTensor)
4543

4644
# Verify shape: should have one code per voxel
4745
self.assertEqual(morton_codes.jdata.shape[0], grid_batch.total_voxels)
@@ -58,7 +56,7 @@ def test_morton_codes(self, device, dtype):
5856
# Test with explicit offset
5957
offset = torch.tensor([10, 10, 10], dtype=torch.int32, device=device)
6058
morton_codes_with_offset = grid_batch.morton(offset=offset)
61-
self.assertIsInstance(morton_codes_with_offset, JaggedTensor)
59+
# self.assertIsInstance(morton_codes_with_offset, JaggedTensor)
6260

6361
@parameterized.expand(all_device_dtype_combos)
6462
def test_hilbert_codes(self, device, dtype):
@@ -73,8 +71,8 @@ def test_hilbert_codes(self, device, dtype):
7371
hilbert_zyx_codes = grid_batch.hilbert_zyx()
7472

7573
# Test that codes are returned as JaggedTensor
76-
self.assertIsInstance(hilbert_codes, JaggedTensor)
77-
self.assertIsInstance(hilbert_zyx_codes, JaggedTensor)
74+
# self.assertIsInstance(hilbert_codes, JaggedTensor)
75+
# self.assertIsInstance(hilbert_zyx_codes, JaggedTensor)
7876

7977
# Verify shape: should have one code per voxel
8078
self.assertEqual(hilbert_codes.jdata.shape[0], grid_batch.total_voxels)
@@ -91,7 +89,7 @@ def test_hilbert_codes(self, device, dtype):
9189
# Test with explicit offset
9290
offset = torch.tensor([10, 10, 10], dtype=torch.int32, device=device)
9391
hilbert_codes_with_offset = grid_batch.hilbert(offset=offset)
94-
self.assertIsInstance(hilbert_codes_with_offset, JaggedTensor)
92+
# self.assertIsInstance(hilbert_codes_with_offset, JaggedTensor)
9593

9694
@parameterized.expand(all_device_dtype_combos)
9795
def test_space_filling_curve_properties(self, device, dtype):
@@ -120,118 +118,6 @@ def test_space_filling_curve_properties(self, device, dtype):
120118
self.assertTrue(torch.all(morton_codes.jdata >= 0))
121119
self.assertTrue(torch.all(hilbert_codes.jdata >= 0))
122120

123-
@parameterized.expand(all_device_dtype_combos)
124-
def test_save_load_gridbatch(self, device, dtype):
125-
"""Test saving and loading grid batches to/from files."""
126-
# Create a test grid batch with data
127-
grid_batch, jagged_data, _ = make_grid_batch_and_jagged_point_data(
128-
device=device, dtype=dtype, include_boundary_points=True
129-
)
130-
131-
# Create temporary file
132-
with tempfile.NamedTemporaryFile(suffix=".nvdb", delete=False) as tmp_file:
133-
tmp_path = tmp_file.name
134-
135-
try:
136-
# Save the grid batch (without data to test structure only)
137-
save_gridbatch(tmp_path, grid_batch, data=None, name="test_grid")
138-
139-
# Load it back (always load to CPU first, then move to device)
140-
loaded_grid_batch, loaded_data, loaded_names = load_gridbatch(tmp_path, device="cpu")
141-
142-
# Move to target device if needed
143-
if device != "cpu":
144-
loaded_grid_batch = loaded_grid_batch.to(device)
145-
146-
# Verify grid structure matches
147-
self.assertEqual(loaded_grid_batch.grid_count, grid_batch.grid_count)
148-
self.assertEqual(loaded_grid_batch.total_voxels, grid_batch.total_voxels)
149-
150-
# Verify voxel counts match for each grid
151-
for i in range(grid_batch.grid_count):
152-
self.assertEqual(loaded_grid_batch.num_voxels_at(i), grid_batch.num_voxels_at(i))
153-
154-
# Verify voxel coordinates match
155-
self.assertTrue(
156-
torch.allclose(loaded_grid_batch.ijk.jdata.float(), grid_batch.ijk.jdata.float(), atol=1e-5)
157-
)
158-
159-
finally:
160-
# Clean up temporary file
161-
if os.path.exists(tmp_path):
162-
os.remove(tmp_path)
163-
164-
@parameterized.expand(all_device_dtype_combos)
165-
def test_save_load_with_names(self, device, dtype):
166-
"""Test saving and loading grid batches with named grids."""
167-
# Create a test grid batch
168-
grid_batch, jagged_data, _ = make_grid_batch_and_jagged_point_data(
169-
device=device, dtype=dtype, include_boundary_points=True
170-
)
171-
172-
# Create names for each grid
173-
grid_names = [f"grid_{i}" for i in range(grid_batch.grid_count)]
174-
175-
# Create temporary file
176-
with tempfile.NamedTemporaryFile(suffix=".nvdb", delete=False) as tmp_file:
177-
tmp_path = tmp_file.name
178-
179-
try:
180-
# Save with names (without data)
181-
save_gridbatch(tmp_path, grid_batch, data=None, names=grid_names)
182-
183-
# Load back
184-
loaded_grid_batch, loaded_data, loaded_names = load_gridbatch(tmp_path, device="cpu")
185-
186-
# Move to target device if needed
187-
if device != "cpu":
188-
loaded_grid_batch = loaded_grid_batch.to(device)
189-
190-
# Verify names match
191-
self.assertEqual(len(loaded_names), len(grid_names))
192-
for original_name, loaded_name in zip(grid_names, loaded_names):
193-
self.assertEqual(original_name, loaded_name)
194-
195-
# Verify grid structure matches
196-
self.assertEqual(loaded_grid_batch.grid_count, grid_batch.grid_count)
197-
198-
finally:
199-
# Clean up temporary file
200-
if os.path.exists(tmp_path):
201-
os.remove(tmp_path)
202-
203-
@parameterized.expand(all_device_dtype_combos)
204-
def test_save_load_compressed(self, device, dtype):
205-
"""Test saving and loading with compression enabled."""
206-
# Create a test grid batch
207-
grid_batch, jagged_data, _ = make_grid_batch_and_jagged_point_data(
208-
device=device, dtype=dtype, include_boundary_points=True
209-
)
210-
211-
# Create temporary file
212-
with tempfile.NamedTemporaryFile(suffix=".nvdb", delete=False) as tmp_file:
213-
tmp_path = tmp_file.name
214-
215-
try:
216-
# Save with compression (without data)
217-
save_gridbatch(tmp_path, grid_batch, data=None, name="compressed_test", compressed=True)
218-
219-
# Load back
220-
loaded_grid_batch, loaded_data, loaded_names = load_gridbatch(tmp_path, device="cpu")
221-
222-
# Move to target device if needed
223-
if device != "cpu":
224-
loaded_grid_batch = loaded_grid_batch.to(device)
225-
226-
# Verify grid structure matches
227-
self.assertEqual(loaded_grid_batch.total_voxels, grid_batch.total_voxels)
228-
self.assertEqual(loaded_grid_batch.grid_count, grid_batch.grid_count)
229-
230-
finally:
231-
# Clean up temporary file
232-
if os.path.exists(tmp_path):
233-
os.remove(tmp_path)
234-
235121

236122
if __name__ == "__main__":
237123
unittest.main()

0 commit comments

Comments
 (0)