Skip to content

Commit 7448d59

Browse files
committed
Replaced torch_scatter's scatter_max and scatter_mean with PyTorch's built-in Tensor.scatter_reduce_ ("amax" and "mean" reductions)
Signed-off-by: Christopher Horvath <[email protected]>
1 parent 5048b62 commit 7448d59

File tree

2 files changed

+228
-5
lines changed

2 files changed

+228
-5
lines changed

surface_reconstruction/nksr/nksr_fvdb/nn/point_encoder.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
import torch
66
import torch.nn as tnn
77
from fvdb.types import DeviceIdentifier, resolve_device
8-
from torch_scatter import scatter_max, scatter_mean
98

109
from .resnet_block import ResnetBlockFC
1110

@@ -220,7 +219,10 @@ def forward(
220219
# Take the values x, which are the network-transformed features only at the valid
221220
# feature indices, and scatter them via max (which removes order dependence) into an
222221
# output vector that's the same size as the original voxel grid.
223-
pooled, _ = scatter_max(x, valid_gvoxel_indices, dim=0, dim_size=grid.total_voxels)
222+
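# Using PyTorch's native scatter_reduce_ with "amax" reduction; include_self=False keeps the zero fill out of the max, so only voxels that receive no points stay zero.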
223+
pooled = torch.zeros(grid.total_voxels, self.size_hidden, dtype=x.dtype, device=x.device)
224+
expanded_indices = valid_gvoxel_indices.unsqueeze(1).expand(-1, self.size_hidden)
225+
pooled.scatter_reduce_(0, expanded_indices, x, reduce="amax", include_self=False)
224226
assert pooled.ndim == 2
225227
assert pooled.shape[0] == grid.total_voxels
226228
assert pooled.shape[1] == self.size_hidden
@@ -258,7 +260,11 @@ def forward(
258260
# Finally, we need to scatter the features back into their corresponding voxel indices.
259261
# Because this is a mean pooling, the ordering doesn't matter.
260262
# Note: Voxels with no overlapping input points will have zero features.
261-
x = scatter_mean(x, valid_gvoxel_indices, dim=0, dim_size=grid.total_voxels)
263+
# Using PyTorch's native scatter_reduce_ with "mean" reduction.
264+
expanded_indices = valid_gvoxel_indices.unsqueeze(1).expand(-1, self.size_output)
265+
result = torch.zeros(grid.total_voxels, self.size_output, dtype=x.dtype, device=x.device)
266+
result.scatter_reduce_(0, expanded_indices, x, reduce="mean", include_self=False)
267+
x = result
262268
assert x.ndim == 2
263269
assert x.shape[0] == grid.total_voxels
264270
assert x.shape[1] == self.size_output

surface_reconstruction/tests/unit/test_point_encoder.py

Lines changed: 219 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -654,7 +654,11 @@ class TestPointEncoderDeterminism(unittest.TestCase):
654654

655655
@parameterized.expand(all_devices)
656656
def test_same_input_produces_same_output(self, device: str) -> None:
657-
"""Test that the same input produces identical output."""
657+
"""Test that the same input produces essentially identical output.
658+
659+
Note: PyTorch's scatter_reduce_ can be non-deterministic on CUDA devices,
660+
so we use allclose to allow for minor floating-point variations.
661+
"""
658662
points_world = generate_street_scene_batch(
659663
batch_size=2,
660664
base_seed=42,
@@ -672,7 +676,11 @@ def test_same_input_produces_same_output(self, device: str) -> None:
672676
features1 = encoder(points_voxel, None, grid)
673677
features2 = encoder(points_voxel, None, grid)
674678

675-
self.assertTrue(torch.equal(features1.jdata, features2.jdata))
679+
# Use allclose to handle potential CUDA non-determinism in scatter operations
680+
self.assertTrue(
681+
torch.allclose(features1.jdata, features2.jdata, atol=1e-5, rtol=1e-5),
682+
"Same input should produce essentially identical output",
683+
)
676684

677685
@parameterized.expand(all_devices)
678686
def test_different_seeds_produce_different_outputs(self, device: str) -> None:
@@ -826,6 +834,215 @@ def test_repr_contains_hyperparameters(self, device: str) -> None:
826834
self.assertIn("block_count=5", s)
827835

828836

837+
class TestPointEncoderScatterReduction(unittest.TestCase):
    """Tests for scatter reduction operations (max pooling and mean pooling).

    Each test feeds the encoder a point set in which several points share a
    voxel, so the many-to-one scatter path is exercised. The tests check the
    output shape, finiteness, run-to-run consistency and gradient flow; they
    do not compare against reference pooled values.
    """

    @parameterized.expand(all_devices)
    def test_multiple_points_per_voxel_produces_valid_output(self, device: str) -> None:
        """Test that voxels with multiple points produce finite features that are not all zero."""
        # Create points where multiple points fall into the same voxel:
        # the three points at 0.1-0.3 share one voxel and the two points at
        # 1.1-1.2 share another (voxel size 1, origin 0).
        points = fvdb.JaggedTensor.from_list_of_tensors(
            [
                torch.tensor(
                    [
                        [0.1, 0.1, 0.1],
                        [0.2, 0.2, 0.2],
                        [0.3, 0.3, 0.3],
                        [1.1, 1.1, 1.1],
                        [1.2, 1.2, 1.2],
                    ],
                    device=device,
                )
            ]
        )
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        with torch.no_grad():
            features = encoder(points, None, grid)

        # Should have exactly 2 voxels with valid features. Assert the count
        # explicitly so the test fails loudly if the points stop clustering.
        self.assertEqual(grid.total_voxels, 2, "Expected the 5 points to occupy exactly 2 voxels")
        self.assertEqual(features.jdata.shape[0], grid.total_voxels)
        self.assertTrue(torch.isfinite(features.jdata).all())
        # Features should not be all zeros (since we have points in all voxels)
        self.assertTrue((features.jdata != 0).any())

    @parameterized.expand(all_devices)
    def test_varying_points_per_voxel(self, device: str) -> None:
        """Test encoder with varying number of points per voxel.

        This test creates multiple points that should cluster into a small number
        of voxels (fewer than the total number of points), exercising the
        many-to-one scatter aggregation.
        """
        # Create points clustered around certain locations.
        # NOTE(review): the exact voxel each point lands in depends on fvdb's
        # voxel-boundary convention (e.g. 20.5 sits on a possible boundary), so
        # only "fewer voxels than points" is asserted below.
        points = fvdb.JaggedTensor.from_list_of_tensors(
            [
                torch.tensor(
                    [
                        # Cluster 1: 1 point
                        [0.5, 0.5, 0.5],
                        # Cluster 2: 3 points (small offsets, intended to share a voxel)
                        [10.1, 10.1, 10.1],
                        [10.2, 10.2, 10.2],
                        [10.3, 10.3, 10.3],
                        # Cluster 3: 5 points (intended to share a voxel)
                        [20.1, 20.1, 20.1],
                        [20.2, 20.2, 20.2],
                        [20.3, 20.3, 20.3],
                        [20.4, 20.4, 20.4],
                        [20.5, 20.5, 20.5],
                    ],
                    device=device,
                )
            ]
        )
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        with torch.no_grad():
            features = encoder(points, None, grid)

        # Key assertion: there should be fewer voxels than input points
        # (demonstrating many-to-one aggregation)
        num_points = points.jdata.shape[0]
        self.assertLess(grid.total_voxels, num_points, "Should have fewer voxels than input points")
        self.assertEqual(features.jdata.shape[0], grid.total_voxels)
        self.assertTrue(torch.isfinite(features.jdata).all())
        # At least some voxels should have non-zero features
        self.assertTrue((features.jdata != 0).any(), "At least some features should be non-zero")

    @parameterized.expand(all_devices)
    def test_dense_points_many_per_voxel(self, device: str) -> None:
        """Test with very dense point cloud - 100 points in a small region."""
        # Create 100 random points within a small region.
        # Use small offsets from a base point to ensure they all fall in the same voxel.
        # The seed is set on the global RNG, so it also affects any random
        # draws made after this point in the test.
        torch.manual_seed(42)
        base = torch.tensor([[10.0, 10.0, 10.0]], device=device)
        offsets = torch.rand(100, 3, device=device) * 0.5  # Small offsets in [0, 0.5)
        single_voxel_points = base + offsets

        points = fvdb.JaggedTensor.from_list_of_tensors([single_voxel_points])
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        with torch.no_grad():
            features = encoder(points, None, grid)

        # Should have exactly 1 voxel (all points in the same voxel)
        self.assertEqual(grid.total_voxels, 1)
        self.assertEqual(features.jdata.shape[0], 1)
        self.assertTrue(torch.isfinite(features.jdata).all())

    @parameterized.expand(all_devices)
    def test_scatter_consistency_with_duplicates(self, device: str) -> None:
        """Test that scatter operations produce consistent results with multiple points per voxel.

        Note: PyTorch's scatter_reduce_ can be non-deterministic on CUDA devices due to
        atomic operation ordering. We use allclose to allow for minor floating-point
        variations while ensuring results are essentially the same.
        """
        points = fvdb.JaggedTensor.from_list_of_tensors(
            [
                torch.tensor(
                    [
                        [10.1, 10.1, 10.1],
                        [10.5, 10.5, 10.5],
                        [10.9, 10.9, 10.9],
                        [20.1, 20.1, 20.1],
                        [20.5, 20.5, 20.5],
                    ],
                    device=device,
                )
            ]
        )
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        # Guard the premise of the test: at least two points must share a voxel,
        # otherwise no many-to-one reduction is exercised.
        num_points = points.jdata.shape[0]
        self.assertLess(grid.total_voxels, num_points, "Should have fewer voxels than input points")

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        with torch.no_grad():
            features1 = encoder(points, None, grid)
            features2 = encoder(points, None, grid)
            features3 = encoder(points, None, grid)

        # All runs should produce essentially identical results
        # Use allclose to handle potential CUDA non-determinism in scatter operations
        self.assertTrue(
            torch.allclose(features1.jdata, features2.jdata, atol=1e-5, rtol=1e-5),
            "Features from run 1 and 2 should be essentially identical",
        )
        self.assertTrue(
            torch.allclose(features2.jdata, features3.jdata, atol=1e-5, rtol=1e-5),
            "Features from run 2 and 3 should be essentially identical",
        )

    @parameterized.expand(all_devices)
    def test_gradients_flow_with_multiple_points_per_voxel(self, device: str) -> None:
        """Test that gradients flow correctly when multiple points map to same voxel."""
        # NOTE(review): the per-point voxel assignments noted below assume
        # voxels span [i, i + 1); if fvdb centres voxels on integer coordinates,
        # 0.5 and 0.9 may land in a neighbouring voxel - confirm. Either way at
        # least two points share a voxel, which is asserted after grid creation.
        points = fvdb.JaggedTensor.from_list_of_tensors(
            [
                torch.tensor(
                    [
                        [0.1, 0.1, 0.1],
                        [0.5, 0.5, 0.5],
                        [0.9, 0.9, 0.9],  # intended: 3 points in voxel [0,0,0]
                        [1.5, 1.5, 1.5],  # intended: 1 point in voxel [1,1,1]
                    ],
                    device=device,
                )
            ]
        )
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        # Guard the premise of the test: some voxel must receive several points.
        num_points = points.jdata.shape[0]
        self.assertLess(grid.total_voxels, num_points, "Should have fewer voxels than input points")

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        # Forward and backward pass
        features = encoder(points, None, grid)
        loss = features.jdata.square().mean()
        loss.backward()

        # All parameters should have gradients
        for name, param in encoder.named_parameters():
            self.assertIsNotNone(param.grad, f"Gradient for {name} is None")
            # Plain assert repeats the check so static type checkers narrow
            # param.grad from Optional[Tensor] to Tensor for the call below.
            assert param.grad is not None
            self.assertTrue(
                torch.isfinite(param.grad).all(),
                f"Gradient for {name} contains non-finite values",
            )

    @parameterized.expand(all_device_batch_combos)
    def test_batched_with_varying_density(self, device: str, batch_size: int) -> None:
        """Test batched input where different batches have different point densities."""
        tensors = []
        for i in range(batch_size):
            # Each batch item has (i+1)*10 points, creating varying density
            num_points = (i + 1) * 10
            # Reseed per item so every batch element is reproducible on its own.
            torch.manual_seed(42 + i)
            pts = torch.rand(num_points, 3, device=device) * 3  # Points in [0, 3) range
            tensors.append(pts)

        points = fvdb.JaggedTensor.from_list_of_tensors(tensors)
        grid = fvdb.GridBatch.from_points(points, voxel_sizes=1, origins=0)

        encoder = PointEncoder(size_feature=3, size_output=32, device=device)

        with torch.no_grad():
            features = encoder(points, None, grid)

        self.assertEqual(features.jdata.shape[0], grid.total_voxels)
        self.assertTrue(torch.isfinite(features.jdata).all())
1045+
8291046
class TestPointEncoderExtraFeatures(unittest.TestCase):
8301047
"""Tests for extra feature handling (colors, normals, etc.)."""
8311048

0 commit comments

Comments
 (0)