22import torch
33
44import pyg_lib
5+ from pyg_lib .testing import withCUDA
56
67
8+ @withCUDA
79@pytest .mark .parametrize ('dtype' , [torch .float , torch .double ])
8- def test_grid_cluster_2d (dtype : torch .dtype ) -> None :
10+ def test_grid_cluster_2d (dtype : torch .dtype , device : torch . device ) -> None :
911 pos = torch .tensor (
1012 [[0.0 , 0.0 ], [0.1 , 0.1 ], [0.5 , 0.5 ], [1.0 , 1.0 ], [1.1 , 1.1 ]],
11- dtype = dtype )
12- size = torch .tensor ([0.5 , 0.5 ], dtype = dtype )
13+ dtype = dtype , device = device )
14+ size = torch .tensor ([0.5 , 0.5 ], dtype = dtype , device = device )
1315
1416 out = pyg_lib .ops .grid_cluster (pos , size )
1517
@@ -21,37 +23,41 @@ def test_grid_cluster_2d(dtype: torch.dtype) -> None:
2123 assert out [3 ] == out [4 ]
2224
2325
26+ @withCUDA
2427@pytest .mark .parametrize ('dtype' , [torch .float , torch .double ])
25- def test_grid_cluster_3d (dtype : torch .dtype ) -> None :
28+ def test_grid_cluster_3d (dtype : torch .dtype , device : torch . device ) -> None :
2629 pos = torch .tensor ([[0.0 , 0.0 , 0.0 ], [0.1 , 0.1 , 0.1 ], [1.0 , 1.0 , 1.0 ]],
27- dtype = dtype )
28- size = torch .tensor ([0.5 , 0.5 , 0.5 ], dtype = dtype )
30+ dtype = dtype , device = device )
31+ size = torch .tensor ([0.5 , 0.5 , 0.5 ], dtype = dtype , device = device )
2932
3033 out = pyg_lib .ops .grid_cluster (pos , size )
3134
3235 assert out [0 ] == out [1 ]
3336 assert out [0 ] != out [2 ]
3437
3538
39+ @withCUDA
3640@pytest .mark .parametrize ('dtype' , [torch .float , torch .double ])
37- def test_grid_cluster_with_start_end (dtype : torch .dtype ) -> None :
38- pos = torch .tensor ([[0.0 , 0.0 ], [0.5 , 0.5 ], [1.0 , 1.0 ]], dtype = dtype )
39- size = torch .tensor ([0.5 , 0.5 ], dtype = dtype )
40- start = torch .tensor ([0.0 , 0.0 ], dtype = dtype )
41- end = torch .tensor ([1.0 , 1.0 ], dtype = dtype )
41+ def test_grid_cluster_with_start_end (dtype : torch .dtype ,
42+ device : torch .device ) -> None :
43+ pos = torch .tensor ([[0.0 , 0.0 ], [0.5 , 0.5 ], [1.0 , 1.0 ]], dtype = dtype ,
44+ device = device )
45+ size = torch .tensor ([0.5 , 0.5 ], dtype = dtype , device = device )
46+ start = torch .tensor ([0.0 , 0.0 ], dtype = dtype , device = device )
47+ end = torch .tensor ([1.0 , 1.0 ], dtype = dtype , device = device )
4248
4349 out = pyg_lib .ops .grid_cluster (pos , size , start , end )
4450
4551 assert out .shape == (3 , )
4652 assert out .dtype == torch .long
4753
4854
49- def test_grid_cluster_defaults_match_explicit () -> None :
55+ @withCUDA
56+ def test_grid_cluster_cpu_cuda_parity (device : torch .device ) -> None :
5057 pos = torch .tensor ([[0.0 , 0.0 ], [0.5 , 0.5 ], [1.0 , 1.0 ]])
5158 size = torch .tensor ([0.5 , 0.5 ])
5259
53- out_default = pyg_lib .ops .grid_cluster (pos , size )
54- out_explicit = pyg_lib .ops .grid_cluster (pos , size , start = pos .min (0 ).values ,
55- end = pos .max (0 ).values )
60+ out_cpu = pyg_lib .ops .grid_cluster (pos , size )
61+ out_dev = pyg_lib .ops .grid_cluster (pos .to (device ), size .to (device ))
5662
57- assert torch .equal (out_default , out_explicit )
63+ assert torch .equal (out_cpu , out_dev . cpu () )
0 commit comments