33import torch
44
55from mmcv .ops import three_interpolate
6+ from mmcv .utils import IS_CUDA_AVAILABLE , IS_NPU_AVAILABLE
67
78
8- @pytest .mark .skipif (
9- not torch .cuda .is_available (), reason = 'requires CUDA support' )
10- @pytest .mark .parametrize ('dtype' , [torch .half , torch .float , torch .double ])
11- def test_three_interpolate (dtype ):
9+ @pytest .mark .parametrize ('dtype' , [
10+ torch .half , torch .float ,
11+ pytest .param (
12+ torch .double ,
13+ marks = pytest .mark .skipif (
14+ IS_NPU_AVAILABLE ,
15+ reason = 'NPU does not support for 64-bit floating point' ))
16+ ])
17+ @pytest .mark .parametrize ('device' , [
18+ pytest .param (
19+ 'cuda' ,
20+ marks = pytest .mark .skipif (
21+ not IS_CUDA_AVAILABLE , reason = 'requires CUDA support' )),
22+ pytest .param (
23+ 'npu' ,
24+ marks = pytest .mark .skipif (
25+ not IS_NPU_AVAILABLE , reason = 'requires NPU support' ))
26+ ])
27+ def test_three_interpolate (dtype , device ):
1228 features = torch .tensor (
1329 [[[2.4350 , 4.7516 , 4.4995 , 2.4350 , 2.4350 , 2.4350 ],
1430 [3.1236 , 2.6278 , 3.0447 , 3.1236 , 3.1236 , 3.1236 ],
@@ -20,12 +36,13 @@ def test_three_interpolate(dtype):
2036 [0.0000 , 0.2744 , 2.0842 , 0.0000 , 0.0000 , 0.0000 ],
2137 [0.3414 , 1.5063 , 1.6209 , 0.3414 , 0.3414 , 0.3414 ],
2238 [0.5814 , 0.0103 , 0.0000 , 0.5814 , 0.5814 , 0.5814 ]]],
23- dtype = dtype ).cuda ()
39+ dtype = dtype ,
40+ device = device )
2441
25- idx = torch .tensor ([[[ 0 , 1 , 2 ], [ 2 , 3 , 4 ], [ 2 , 3 , 4 ], [ 0 , 1 , 2 ], [ 0 , 1 , 2 ],
26- [0 , 1 , 3 ]],
27- [[0 , 2 , 3 ], [1 , 3 , 4 ], [2 , 1 , 4 ], [0 , 2 , 4 ], [0 , 2 , 4 ],
28- [ 0 , 1 , 2 ]]] ).int (). cuda ()
42+ idx = torch .tensor (
43+ [[[ 0 , 1 , 2 ], [ 2 , 3 , 4 ], [ 2 , 3 , 4 ], [ 0 , 1 , 2 ], [ 0 , 1 , 2 ], [0 , 1 , 3 ]],
44+ [[0 , 2 , 3 ], [1 , 3 , 4 ], [2 , 1 , 4 ], [0 , 2 , 4 ], [0 , 2 , 4 ], [ 0 , 1 , 2 ]] ],
45+ device = device ).int ()
2946
3047 weight = torch .tensor ([[[3.3333e-01 , 3.3333e-01 , 3.3333e-01 ],
3148 [1.0000e+00 , 5.8155e-08 , 2.2373e-08 ],
@@ -39,7 +56,8 @@ def test_three_interpolate(dtype):
3956 [3.3333e-01 , 3.3333e-01 , 3.3333e-01 ],
4057 [3.3333e-01 , 3.3333e-01 , 3.3333e-01 ],
4158 [3.3333e-01 , 3.3333e-01 , 3.3333e-01 ]]],
42- dtype = dtype ).cuda ()
59+ dtype = dtype ,
60+ device = device )
4361
4462 output = three_interpolate (features , idx , weight )
4563 expected_output = torch .tensor ([[[
@@ -73,6 +91,7 @@ def test_three_interpolate(dtype):
7391 3.8760e-01 , 1.0300e-02 , 8.3569e-09 ,
7492 3.8760e-01 , 3.8760e-01 , 1.9723e-01
7593 ]]],
76- dtype = dtype ).cuda ()
94+ dtype = dtype ,
95+ device = device )
7796
7897 assert torch .allclose (output , expected_output , 1e-3 , 1e-4 )
0 commit comments