@@ -654,7 +654,11 @@ class TestPointEncoderDeterminism(unittest.TestCase):
654654
655655 @parameterized .expand (all_devices )
656656 def test_same_input_produces_same_output (self , device : str ) -> None :
657- """Test that the same input produces identical output."""
657+ """Test that the same input produces essentially identical output.
658+
659+ Note: PyTorch's scatter_reduce_ can be non-deterministic on CUDA devices,
660+ so we use allclose to allow for minor floating-point variations.
661+ """
658662 points_world = generate_street_scene_batch (
659663 batch_size = 2 ,
660664 base_seed = 42 ,
@@ -672,7 +676,11 @@ def test_same_input_produces_same_output(self, device: str) -> None:
672676 features1 = encoder (points_voxel , None , grid )
673677 features2 = encoder (points_voxel , None , grid )
674678
675- self .assertTrue (torch .equal (features1 .jdata , features2 .jdata ))
679+ # Use allclose to handle potential CUDA non-determinism in scatter operations
680+ self .assertTrue (
681+ torch .allclose (features1 .jdata , features2 .jdata , atol = 1e-5 , rtol = 1e-5 ),
682+ "Same input should produce essentially identical output" ,
683+ )
676684
677685 @parameterized .expand (all_devices )
678686 def test_different_seeds_produce_different_outputs (self , device : str ) -> None :
@@ -826,6 +834,215 @@ def test_repr_contains_hyperparameters(self, device: str) -> None:
826834 self .assertIn ("block_count=5" , s )
827835
828836
837+ class TestPointEncoderScatterReduction (unittest .TestCase ):
838+ """Tests for scatter reduction operations (max pooling and mean pooling).
839+
840+ These tests specifically verify that the scatter operations correctly
841+ aggregate multiple points per voxel. This is critical for the PointNet-style
842+ architecture where many points can fall into the same voxel.
843+ """
844+
845+ @parameterized .expand (all_devices )
846+ def test_multiple_points_per_voxel_produces_valid_output (self , device : str ) -> None :
847+ """Test that voxels with multiple points produce finite, non-zero features."""
848+ # Create points where multiple points fall into the same voxel
849+ # Points at [0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3] all map to voxel [0,0,0]
850+ # Points at [1.1, 1.1, 1.1], [1.2, 1.2, 1.2] map to voxel [1,1,1]
851+ points = fvdb .JaggedTensor .from_list_of_tensors (
852+ [
853+ torch .tensor (
854+ [
855+ [0.1 , 0.1 , 0.1 ],
856+ [0.2 , 0.2 , 0.2 ],
857+ [0.3 , 0.3 , 0.3 ],
858+ [1.1 , 1.1 , 1.1 ],
859+ [1.2 , 1.2 , 1.2 ],
860+ ],
861+ device = device ,
862+ )
863+ ]
864+ )
865+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
866+
867+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
868+
869+ with torch .no_grad ():
870+ features = encoder (points , None , grid )
871+
872+ # Should have exactly 2 voxels with valid features
873+ self .assertEqual (features .jdata .shape [0 ], grid .total_voxels )
874+ self .assertTrue (torch .isfinite (features .jdata ).all ())
875+ # Features should not be all zeros (since we have points in all voxels)
876+ self .assertTrue ((features .jdata != 0 ).any ())
877+
878+ @parameterized .expand (all_devices )
879+ def test_varying_points_per_voxel (self , device : str ) -> None :
880+ """Test encoder with varying number of points per voxel.
881+
882+ This test creates multiple points that should cluster into a small number
883+ of voxels (fewer than the total number of points), exercising the
884+ many-to-one scatter aggregation.
885+ """
886+ # Create points clustered around certain locations
887+ points = fvdb .JaggedTensor .from_list_of_tensors (
888+ [
889+ torch .tensor (
890+ [
891+ # Cluster 1: 1 point
892+ [0.5 , 0.5 , 0.5 ],
893+ # Cluster 2: 3 points (all in same voxel due to small offsets)
894+ [10.1 , 10.1 , 10.1 ],
895+ [10.2 , 10.2 , 10.2 ],
896+ [10.3 , 10.3 , 10.3 ],
897+ # Cluster 3: 5 points (all in same voxel)
898+ [20.1 , 20.1 , 20.1 ],
899+ [20.2 , 20.2 , 20.2 ],
900+ [20.3 , 20.3 , 20.3 ],
901+ [20.4 , 20.4 , 20.4 ],
902+ [20.5 , 20.5 , 20.5 ],
903+ ],
904+ device = device ,
905+ )
906+ ]
907+ )
908+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
909+
910+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
911+
912+ with torch .no_grad ():
913+ features = encoder (points , None , grid )
914+
915+ # Key assertion: there should be fewer voxels than input points
916+ # (demonstrating many-to-one aggregation)
917+ num_points = points .jdata .shape [0 ]
918+ self .assertLess (grid .total_voxels , num_points , "Should have fewer voxels than input points" )
919+ self .assertEqual (features .jdata .shape [0 ], grid .total_voxels )
920+ self .assertTrue (torch .isfinite (features .jdata ).all ())
921+ # At least some voxels should have non-zero features
922+ self .assertTrue ((features .jdata != 0 ).any (), "At least some features should be non-zero" )
923+
924+ @parameterized .expand (all_devices )
925+ def test_dense_points_many_per_voxel (self , device : str ) -> None :
926+ """Test with very dense point cloud - 100 points in a small region."""
927+ # Create 100 random points within a small region
928+ # Use small offsets from a base point to ensure they all fall in the same voxel
929+ torch .manual_seed (42 )
930+ base = torch .tensor ([[10.0 , 10.0 , 10.0 ]], device = device )
931+ offsets = torch .rand (100 , 3 , device = device ) * 0.5 # Small offsets in [0, 0.5)
932+ single_voxel_points = base + offsets
933+
934+ points = fvdb .JaggedTensor .from_list_of_tensors ([single_voxel_points ])
935+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
936+
937+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
938+
939+ with torch .no_grad ():
940+ features = encoder (points , None , grid )
941+
942+ # Should have exactly 1 voxel (all points in the same voxel)
943+ self .assertEqual (grid .total_voxels , 1 )
944+ self .assertEqual (features .jdata .shape [0 ], 1 )
945+ self .assertTrue (torch .isfinite (features .jdata ).all ())
946+
947+ @parameterized .expand (all_devices )
948+ def test_scatter_consistency_with_duplicates (self , device : str ) -> None :
949+ """Test that scatter operations produce consistent results with multiple points per voxel.
950+
951+ Note: PyTorch's scatter_reduce_ can be non-deterministic on CUDA devices due to
952+ atomic operation ordering. We use allclose to allow for minor floating-point
953+ variations while ensuring results are essentially the same.
954+ """
955+ points = fvdb .JaggedTensor .from_list_of_tensors (
956+ [
957+ torch .tensor (
958+ [
959+ [10.1 , 10.1 , 10.1 ],
960+ [10.5 , 10.5 , 10.5 ],
961+ [10.9 , 10.9 , 10.9 ],
962+ [20.1 , 20.1 , 20.1 ],
963+ [20.5 , 20.5 , 20.5 ],
964+ ],
965+ device = device ,
966+ )
967+ ]
968+ )
969+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
970+
971+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
972+
973+ with torch .no_grad ():
974+ features1 = encoder (points , None , grid )
975+ features2 = encoder (points , None , grid )
976+ features3 = encoder (points , None , grid )
977+
978+ # All runs should produce essentially identical results
979+ # Use allclose to handle potential CUDA non-determinism in scatter operations
980+ self .assertTrue (
981+ torch .allclose (features1 .jdata , features2 .jdata , atol = 1e-5 , rtol = 1e-5 ),
982+ "Features from run 1 and 2 should be essentially identical" ,
983+ )
984+ self .assertTrue (
985+ torch .allclose (features2 .jdata , features3 .jdata , atol = 1e-5 , rtol = 1e-5 ),
986+ "Features from run 2 and 3 should be essentially identical" ,
987+ )
988+
989+ @parameterized .expand (all_devices )
990+ def test_gradients_flow_with_multiple_points_per_voxel (self , device : str ) -> None :
991+ """Test that gradients flow correctly when multiple points map to same voxel."""
992+ points = fvdb .JaggedTensor .from_list_of_tensors (
993+ [
994+ torch .tensor (
995+ [
996+ [0.1 , 0.1 , 0.1 ],
997+ [0.5 , 0.5 , 0.5 ],
998+ [0.9 , 0.9 , 0.9 ], # 3 points in voxel [0,0,0]
999+ [1.5 , 1.5 , 1.5 ], # 1 point in voxel [1,1,1]
1000+ ],
1001+ device = device ,
1002+ )
1003+ ]
1004+ )
1005+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
1006+
1007+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
1008+
1009+ # Forward and backward pass
1010+ features = encoder (points , None , grid )
1011+ loss = features .jdata .square ().mean ()
1012+ loss .backward ()
1013+
1014+ # All parameters should have gradients
1015+ for name , param in encoder .named_parameters ():
1016+ self .assertIsNotNone (param .grad , f"Gradient for { name } is None" )
1017+ assert param .grad is not None
1018+ self .assertTrue (
1019+ torch .isfinite (param .grad ).all (),
1020+ f"Gradient for { name } contains non-finite values" ,
1021+ )
1022+
1023+ @parameterized .expand (all_device_batch_combos )
1024+ def test_batched_with_varying_density (self , device : str , batch_size : int ) -> None :
1025+ """Test batched input where different batches have different point densities."""
1026+ tensors = []
1027+ for i in range (batch_size ):
1028+ # Each batch item has (i+1)*10 points, creating varying density
1029+ num_points = (i + 1 ) * 10
1030+ torch .manual_seed (42 + i )
1031+ pts = torch .rand (num_points , 3 , device = device ) * 3 # Points in [0, 3) range
1032+ tensors .append (pts )
1033+
1034+ points = fvdb .JaggedTensor .from_list_of_tensors (tensors )
1035+ grid = fvdb .GridBatch .from_points (points , voxel_sizes = 1 , origins = 0 )
1036+
1037+ encoder = PointEncoder (size_feature = 3 , size_output = 32 , device = device )
1038+
1039+ with torch .no_grad ():
1040+ features = encoder (points , None , grid )
1041+
1042+ self .assertEqual (features .jdata .shape [0 ], grid .total_voxels )
1043+ self .assertTrue (torch .isfinite (features .jdata ).all ())
1044+
1045+
8291046class TestPointEncoderExtraFeatures (unittest .TestCase ):
8301047 """Tests for extra feature handling (colors, normals, etc.)."""
8311048
0 commit comments