@@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
7272 ], [
7373 2.2060e-01 , 3.4110e-01 , 3.4110e-01 , 2.2060e-01 , 2.2060e-01 , 2.1380e-01
7474 ]],
75- [[
76- 8.1773e-01 , 9.5440e-01 , 2.4532e+00 ,
77- 8.1773e-01 , 8.1773e-01 , 1.1359e+00
78- ],
79- [
80- 8.4689e-01 , 1.9176e+00 , 1.4715e+00 ,
81- 8.4689e-01 , 8.4689e-01 , 1.3079e+00
82- ],
83- [
84- 6.9473e-01 , 2.7440e-01 , 2.0842e+00 ,
85- 6.9473e-01 , 6.9473e-01 , 7.8619e-01
86- ],
87- [
88- 7.6789e-01 , 1.5063e+00 , 1.6209e+00 ,
89- 7.6789e-01 , 7.6789e-01 , 1.1562e+00
90- ],
91- [
92- 3.8760e-01 , 1.0300e-02 , 8.3569e-09 ,
93- 3.8760e-01 , 3.8760e-01 , 1.9723e-01
94- ]]],
95- dtype = dtype ,
96- device = device )
75+ [[
76+ 8.1773e-01 , 9.5440e-01 , 2.4532e+00 ,
77+ 8.1773e-01 , 8.1773e-01 , 1.1359e+00
78+ ],
79+ [
80+ 8.4689e-01 , 1.9176e+00 , 1.4715e+00 ,
81+ 8.4689e-01 , 8.4689e-01 , 1.3079e+00
82+ ],
83+ [
84+ 6.9473e-01 , 2.7440e-01 , 2.0842e+00 ,
85+ 6.9473e-01 , 6.9473e-01 , 7.8619e-01
86+ ],
87+ [
88+ 7.6789e-01 , 1.5063e+00 , 1.6209e+00 ,
89+ 7.6789e-01 , 7.6789e-01 , 1.1562e+00
90+ ],
91+ [
92+ 3.8760e-01 , 1.0300e-02 , 8.3569e-09 ,
93+ 3.8760e-01 , 3.8760e-01 , 1.9723e-01
94+ ]]],
95+ dtype = dtype ,
96+ device = device )
9797
9898 assert torch .allclose (output , expected_output , 1e-3 , 1e-4 )
9999
@@ -106,16 +106,17 @@ def three_interpolate_forward_gloden(features, idx, weight):
106106 if dtype == np .float16 :
107107 features = features .astype (np .float32 )
108108 weight = weight .astype (np .float32 )
109-
110- output = np .zeros ((bs , cs , ns ), dtype = np .float )
109+
110+ output = np .zeros ((bs , cs , ns ), dtype = np .float )
111111 for b in range (bs ):
112112 for c in range (cs ):
113113 for n in range (ns ):
114114 output [b ][c ][n ] = features [b ][c ][idx [b ][n ][0 ]] * weight [b ][n ][0 ] \
115- + features [b ][c ][idx [b ][n ][1 ]] * weight [b ][n ][1 ] \
116- + features [b ][c ][idx [b ][n ][2 ]] * weight [b ][n ][2 ]
115+ + features [b ][c ][idx [b ][n ][1 ]] * weight [b ][n ][1 ] \
116+ + features [b ][c ][idx [b ][n ][2 ]] * weight [b ][n ][2 ]
117117 return output
118118
119+
119120def three_interpolate_backward_gloden (grad_output , idx , weight , features ):
120121 bs , cs , ns = grad_output .shape
121122 ms = features .shape [2 ]
@@ -124,7 +125,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
124125 if dtype == np .float16 :
125126 features = features .astype (np .float32 )
126127 weight = weight .astype (np .float32 )
127-
128+
128129 grad_point = np .zeros ((bs , cs , ms ), dtype = np .float )
129130 for b in range (bs ):
130131 for c in range (cs ):
@@ -137,6 +138,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
137138 grad_output [b ][c ][n ] * weight [b ][n ][2 ]
138139 return grad_point
139140
141+
140142def torch_type_trans (dtype ):
141143 if dtype == torch .half :
142144 return np .float16
@@ -145,27 +147,36 @@ def torch_type_trans(dtype):
145147 else :
146148 return np .float64
147149
150+
148151@pytest .mark .parametrize ('dtype' , [
149152 torch .half ,
150153 torch .float
151154])
152155@pytest .mark .parametrize ('device' , [
153- (2 ,5 ,6 ,6 ),
154- (10 ,10 ,10 ,10 ),
155- (20 ,21 ,13 ,4 ),
156- (2 ,10 ,2 ,18 ),
157- (10 ,602 ,910 ,200 ),
158- (600 ,100 ,300 ,101 )
156+ pytest .param (
157+ 'npu' ,
158+ marks = pytest .mark .skipif (
159+ not IS_NPU_AVAILABLE , reason = 'requires NPU support' ))
160+ ])
161+ @pytest .mark .parametrize ('shape' , [
162+ (2 , 5 , 6 , 6 ),
163+ (10 , 10 , 10 , 10 ),
164+ (20 , 21 , 13 , 4 ),
165+ (2 , 10 , 2 , 18 ),
166+ (10 , 602 , 910 , 200 ),
167+ (600 , 100 , 300 , 101 )
159168])
160169def test_three_interpolate_npu_dynamic_shape (dtype , device , shape ):
161170 bs = shape [0 ]
162171 cs = shape [1 ]
163172 ms = shape [2 ]
164173 ns = shape [3 ]
165174
166- features = np .random .uniform (- 10.0 , 10.0 , (bs , cs , ms ).astype (torch_type_trans (dtype )))
175+ features = np .random .uniform (- 10.0 , 10.0 ,
176+ (bs , cs , ms ).astype (torch_type_trans (dtype )))
167177 idx = np .random .uniform (0 , ms , size = (bs , ns , 3 ), dtype = np .int32 )
168- weight = np .random .uniform (- 10.0 , 10.0 (bs , ns , 3 )).astype (torch_type_trans (dtype ))
178+ weight = np .random .uniform (- 10.0 , 10.0 (bs , ns , 3 )
179+ ).astype (torch_type_trans (dtype ))
169180
170181 features_npu = torch .tensor (features , dtype = dtype ).to (device )
171182 idx_npu = torch .tensor (idx , dtype = torch .int32 ).to (device )
0 commit comments