@@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
      ], [
          2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
      ]],
-    [[
-        8.1773e-01, 9.5440e-01, 2.4532e+00,
-        8.1773e-01, 8.1773e-01, 1.1359e+00
-    ],
-    [
-        8.4689e-01, 1.9176e+00, 1.4715e+00,
-        8.4689e-01, 8.4689e-01, 1.3079e+00
-    ],
-    [
-        6.9473e-01, 2.7440e-01, 2.0842e+00,
-        6.9473e-01, 6.9473e-01, 7.8619e-01
-    ],
-    [
-        7.6789e-01, 1.5063e+00, 1.6209e+00,
-        7.6789e-01, 7.6789e-01, 1.1562e+00
-    ],
-    [
-        3.8760e-01, 1.0300e-02, 8.3569e-09,
-        3.8760e-01, 3.8760e-01, 1.9723e-01
-    ]]],
-    dtype=dtype,
-    device=device)
+        [[
+            8.1773e-01, 9.5440e-01, 2.4532e+00,
+            8.1773e-01, 8.1773e-01, 1.1359e+00
+        ],
+        [
+            8.4689e-01, 1.9176e+00, 1.4715e+00,
+            8.4689e-01, 8.4689e-01, 1.3079e+00
+        ],
+        [
+            6.9473e-01, 2.7440e-01, 2.0842e+00,
+            6.9473e-01, 6.9473e-01, 7.8619e-01
+        ],
+        [
+            7.6789e-01, 1.5063e+00, 1.6209e+00,
+            7.6789e-01, 7.6789e-01, 1.1562e+00
+        ],
+        [
+            3.8760e-01, 1.0300e-02, 8.3569e-09,
+            3.8760e-01, 3.8760e-01, 1.9723e-01
+        ]]],
+        dtype=dtype,
+        device=device)

     assert torch.allclose(output, expected_output, 1e-3, 1e-4)

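For context, a minimal usage sketch of the op exercised by this test, assuming mmcv.ops.three_interpolate and a CUDA (or NPU) build of the extension; the layout matches the test's features (B, C, M), idx (B, N, 3) and weight (B, N, 3) tensors, but the concrete sizes and the .cuda() device below are illustrative, not taken from the test:

import torch
from mmcv.ops import three_interpolate

# (B, C, M) source features, plus 3 neighbour indices and weights per target point
features = torch.rand(1, 5, 4).cuda()
idx = torch.randint(0, 4, (1, 6, 3), dtype=torch.int32).cuda()
weight = torch.rand(1, 6, 3).cuda()

# output has shape (B, C, N) == (1, 5, 6): one interpolated feature vector per target
output = three_interpolate(features, idx, weight)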
@@ -148,24 +148,16 @@ def torch_type_trans(dtype):
         return np.float64


-@pytest.mark.parametrize('dtype', [
-    torch.half,
-    torch.float
-])
+@pytest.mark.parametrize('dtype', [torch.half, torch.float])
 @pytest.mark.parametrize('device', [
     pytest.param(
         'npu',
         marks=pytest.mark.skipif(
             not IS_NPU_AVAILABLE, reason='requires NPU support'))
 ])
-@pytest.mark.parametrize('shape', [
-    (2, 5, 6, 6),
-    (10, 10, 10, 10),
-    (20, 21, 13, 4),
-    (2, 10, 2, 18),
-    (10, 602, 910, 200),
-    (600, 100, 300, 101)
-])
+@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
+                                   (20, 21, 13, 4), (2, 10, 2, 18),
+                                   (10, 602, 910, 200), (600, 100, 300, 101)])
 def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
     bs = shape[0]
     cs = shape[1]
@@ -175,13 +167,14 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
     features = np.random.uniform(-10.0, 10.0,
                                  (bs, cs, ms)).astype(torch_type_trans(dtype))
     idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32)
-    weight = np.random.uniform(-10.0, 10.0, (bs, ns, 3)
-                               ).astype(torch_type_trans(dtype))
+    weight = np.random.uniform(-10.0,
+                               10.0, (bs, ns,
+                                      3)).astype(torch_type_trans(dtype))

     features_npu = torch.tensor(features, dtype=dtype).to(device)
     idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
     weight_npu = torch.tensor(weight, dtype=dtype).to(device)

     expected_output = three_interpolate_forward_gloden(features, idx, weight)
     output = three_interpolate(features_npu, idx_npu, weight_npu)
-    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
+    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
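For reference, a NumPy sketch of what a golden check such as three_interpolate_forward_gloden (defined earlier in this file) computes: each target point's feature vector is the weighted sum of the features of its three indexed source points. The function name and explicit loops below are illustrative, not the file's exact helper:

import numpy as np

def three_interpolate_reference(features, idx, weight):
    # features: (B, C, M), idx: (B, N, 3) integer indices into M, weight: (B, N, 3)
    bs, cs, _ = features.shape
    ns = idx.shape[1]
    output = np.zeros((bs, cs, ns), dtype=features.dtype)
    for b in range(bs):
        for n in range(ns):
            for k in range(3):
                # weighted contribution of the k-th neighbour's feature column
                output[b, :, n] += weight[b, n, k] * features[b, :, idx[b, n, k]]
    return output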