Commit 0e74442

Override Ops for failed UTs caused by torch-xpu-ops (#4961) (#4981)
* fix UTs
* fix format
* add device check
1 parent 1d2371c commit 0e74442

15 files changed: +148 additions, -48 deletions

csrc/gpu/aten/operators/BatchNorm.cpp

Lines changed: 96 additions & 0 deletions
@@ -5208,6 +5208,90 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce(
       grad_output, input, mean, invstd, weight_opt, input_g, weight_g, bias_g);
 }
 
+#ifdef USE_OVERRIDE_OP
+// Rename below functions because they have overload with the same name
+// and can't be registered.
+std::tuple<Tensor, Tensor, Tensor> _native_batch_norm_legit_(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean,
+    Tensor& running_var,
+    bool train,
+    double momentum,
+    double epsilon) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit(
+      self,
+      weight_opt,
+      bias_opt,
+      running_mean,
+      running_var,
+      train,
+      momentum,
+      epsilon);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _native_batch_norm_legit_no_state(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    bool train,
+    double momentum,
+    double epsilon) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit(
+      self, weight_opt, bias_opt, train, momentum, epsilon);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_out_(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean,
+    Tensor& running_var,
+    bool train,
+    double momentum,
+    double epsilon,
+    Tensor& output,
+    Tensor& save_mean,
+    Tensor& save_invstd) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit_out(
+      self,
+      weight_opt,
+      bias_opt,
+      running_mean,
+      running_var,
+      train,
+      momentum,
+      epsilon,
+      output,
+      save_mean,
+      save_invstd);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_state_out(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    bool train,
+    double momentum,
+    double epsilon,
+    Tensor& output,
+    Tensor& save_mean,
+    Tensor& save_invstd) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit_out(
+      self,
+      weight_opt,
+      bias_opt,
+      train,
+      momentum,
+      epsilon,
+      output,
+      save_mean,
+      save_invstd);
+}
+
+#endif
+
 } // namespace AtenIpexTypeXPU
 } // namespace at

@@ -5223,6 +5307,18 @@ IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl(
       "native_batch_norm_backward",
      TORCH_FN((&at::AtenIpexTypeXPU::native_batch_norm_backward)));
+  m.impl(
+      "_native_batch_norm_legit",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_)));
+  m.impl(
+      "_native_batch_norm_legit.out",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_out_)));
+  m.impl(
+      "_native_batch_norm_legit.no_stats",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_no_state)));
+  m.impl(
+      "_native_batch_norm_legit.no_stats_out",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_no_state_out)));
 }
 
 } // namespace
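
The overridden entry points can be smoke-tested from Python through the aten op. A minimal sketch, assuming an XPU-enabled build of intel_extension_for_pytorch and an available "xpu" device (the re-enabled test_batch_norm_legit_simple below exercises the same path):

import torch
import intel_extension_for_pytorch  # noqa: F401  # registers the XPU kernels

x = torch.randn(1, 2, 3, 3, device="xpu")
weight = torch.ones(2, device="xpu")
bias = torch.zeros(2, device="xpu")
running_mean = torch.zeros(2, device="xpu")
running_var = torch.ones(2, device="xpu")

# Stateful overload: resolves to _native_batch_norm_legit_ above.
out, save_mean, save_invstd = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)

# Stats-free overload (.no_stats): resolves to _native_batch_norm_legit_no_state.
out2, save_mean2, save_invstd2 = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, True, 0.1, 1e-5
)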

csrc/gpu/aten/operators/EmbeddingBag.cpp

Lines changed: 16 additions & 0 deletions
@@ -4,6 +4,9 @@
 #include <core/Memory.h>
 #include <runtime/Utils.h>
 #include <torch/torch.h>
+#ifdef USE_OVERRIDE_OP
+#include "utils/CustomOperatorRegistration.h"
+#endif
 #include <utils/DPCPP.h>
 
 #include "BitonicMergeSort.h"

@@ -1294,3 +1297,16 @@ Tensor _embedding_bag_per_sample_weights_backward(
 
 } // namespace AtenIpexTypeXPU
 } // namespace at
+
+#ifdef USE_OVERRIDE_OP
+namespace {
+
+IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
+  m.impl("_embedding_bag", TORCH_FN((&at::AtenIpexTypeXPU::_embedding_bag)));
+  m.impl(
+      "_embedding_bag_forward_only",
+      TORCH_FN((&at::AtenIpexTypeXPU::_embedding_bag_forward_only)));
+}
+
+} // namespace
+#endif
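
torch.nn.EmbeddingBag lowers to aten::_embedding_bag, so the override can be exercised directly from Python. A minimal sketch, assuming an XPU device is available:

import torch
import intel_extension_for_pytorch  # noqa: F401

bag = torch.nn.EmbeddingBag(10, 3, mode="sum").to("xpu")
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device="xpu")
offsets = torch.tensor([0, 4], device="xpu")  # two bags of four indices each

# Forward dispatches to aten::_embedding_bag (or _embedding_bag_forward_only
# when no gradients are required), now served by the kernels registered above.
out = bag(indices, offsets)
print(out.shape)  # torch.Size([2, 3])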

csrc/gpu/aten/operators/GatedLinearUnit.cpp

Lines changed: 16 additions & 0 deletions
@@ -3,6 +3,9 @@
 #include <ATen/OpMathType.h>
 #include <ATen/TensorUtils.h>
 #include <runtime/Utils.h>
+#ifdef USE_OVERRIDE_OP
+#include "utils/CustomOperatorRegistration.h"
+#endif
 #include <utils/DPCPP.h>
 
 #include "Loops.h"

@@ -208,3 +211,16 @@ Tensor glu_backward_jvp(
 
 } // namespace AtenIpexTypeXPU
 } // namespace at
+
+#ifdef USE_OVERRIDE_OP
+namespace {
+
+IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
+  m.impl("glu_backward", TORCH_FN((&at::AtenIpexTypeXPU::glu_backward)));
+  m.impl(
+      "glu_backward.grad_input",
+      TORCH_FN((&at::AtenIpexTypeXPU::glu_backward_out)));
+}
+
+} // namespace
+#endif
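
glu_backward is reached through autograd, so a backward pass through F.glu covers the new registrations. A minimal sketch, assuming an XPU device:

import torch
import torch.nn.functional as F
import intel_extension_for_pytorch  # noqa: F401

x = torch.randn(4, 6, device="xpu", requires_grad=True)
# glu splits the chosen dim into halves a, b and computes a * sigmoid(b).
y = F.glu(x, dim=-1)
# The backward pass dispatches to aten::glu_backward, registered above.
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 6])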

csrc/gpu/aten/operators/Indexing.cpp

Lines changed: 6 additions & 0 deletions
@@ -1439,6 +1439,9 @@ Tensor& index_select_out(
     int64_t dim,
     const Tensor& index,
     Tensor& out) {
+  TORCH_CHECK(self.is_xpu(), "self must be a XPU tensor.");
+  TORCH_CHECK(out.is_xpu(), "out must be a XPU tensor.");
+
   IPEX_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,

@@ -2334,6 +2337,9 @@ Tensor& index_out(
 IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl(
       "_index_put_impl_", TORCH_FN((&at::AtenIpexTypeXPU::_index_put_impl_)));
+  m.impl("index_select", TORCH_FN((&at::AtenIpexTypeXPU::index_select)));
+  m.impl(
+      "index_select.out", TORCH_FN((&at::AtenIpexTypeXPU::index_select_out)));
   m.impl("nonzero", TORCH_FN((&at::AtenIpexTypeXPU::nonzero)));
   m.impl("nonzero.out", TORCH_FN((&at::AtenIpexTypeXPU::nonzero_out)));
 }
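
Both overloads, plus the new device check, can be exercised from Python. A minimal sketch, assuming an XPU device; the error path shown is what the new TORCH_CHECKs are intended to produce:

import torch
import intel_extension_for_pytorch  # noqa: F401

src = torch.randn(5, 4, device="xpu")
idx = torch.tensor([0, 2, 4], device="xpu")

picked = torch.index_select(src, 0, idx)   # aten::index_select
out = torch.empty(3, 4, device="xpu")
torch.index_select(src, 0, idx, out=out)   # aten::index_select.out

# The added TORCH_CHECKs should surface a clear error, rather than a kernel
# crash, when `out` lives on the wrong device:
try:
    torch.index_select(src, 0, idx, out=torch.empty(3, 4))  # CPU `out`
except RuntimeError as err:
    print(err)  # "out must be a XPU tensor."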

scripts/tools/torchgen/yaml/xpu_functions.yaml

Lines changed: 14 additions & 14 deletions
@@ -18,10 +18,20 @@ supported:
 # - col2im.out
 # - im2col
 # - im2col.out
-# - sort
-# - sort.stable
-# - sort.values
-# - sort.values_stable
+# - _embedding_bag
+# - _embedding_bag_forward_only
+# - _native_batch_norm_legit
+# - _native_batch_norm_legit.out
+# - _native_batch_norm_legit.no_stats
+# - _native_batch_norm_legit.no_stats_out
+# - glu_backward
+# - glu_backward.grad_input
+# - index_select
+# - index_select.out
+# - sort
+# - sort.stable
+# - sort.values
+# - sort.values_stable
 ################## override below ops due to performance issues
 # - convolution_overrideable
 # - convolution_backward_overrideable

@@ -82,9 +92,7 @@ supported:
 # - cumsum.out # newly added
 - _dirichlet_grad
 # - _efficientzerotensor
-# - _embedding_bag
 - _embedding_bag_dense_backward
-# - _embedding_bag_forward_only
 - _embedding_bag_per_sample_weights_backward
 - _empty_affine_quantized
 - _empty_per_channel_affine_quantized

@@ -191,10 +199,6 @@ supported:
 # - batch_norm_stats
 - batch_norm_stats.out
 # - batch_norm_update_stats
-# - _native_batch_norm_legit
-# - _native_batch_norm_legit.out
-# - _native_batch_norm_legit.no_stats
-# - _native_batch_norm_legit.no_stats_out
 # - batch_norm_backward_elemt
 # - batch_norm_backward_reduce
 # - bernoulli_.Tensor

@@ -316,8 +320,6 @@ supported:
 # - logit.out
 # - glu
 # - glu.out
-# - glu_backward
-# - glu_backward.grad_input
 - glu_backward_jvp
 - glu_jvp
 # - gt.Scalar

@@ -355,8 +357,6 @@ supported:
 - _unsafe_index.Tensor
 # - index_fill_.int_Scalar
 # - index_fill_.int_Tensor
-# - index_select
-# - index_select.out
 # - index_add.out
 # - inverse
 # - inverse.out

tests/gpu/examples/test_batch_norm.py

Lines changed: 0 additions & 3 deletions
@@ -593,9 +593,6 @@ def test_batch_norm_update_stats_simple(self):
         self.assertEqual(save_mean_cpu, save_mean_dpcpp.to(cpu_device))
         self.assertEqual(save_var_cpu, save_var_dpcpp.to(cpu_device))
 
-    @pytest.mark.skip(
-        reason="PT2.5: TensorAccessor expected 1 dims but tensor has 4",
-    )
     def test_batch_norm_legit_simple(self):
         input_cpu = torch.randn(1, 2, 3, 3, dtype=torch.float, device=cpu_device)
         n_input = input_cpu.size(1)

tests/gpu/examples/test_cat_array.py

Lines changed: 0 additions & 3 deletions
@@ -135,9 +135,6 @@ def test_cat_block_layout(self, dtype=torch.float):
     @pytest.mark.skipif(
         torch.xpu.device_count() == 1, reason="doesn't support with one device"
     )
-    @pytest.mark.skip(
-        reason="PT2.5: Native API failed. Native API returns: -36 (PI_ERROR_INVALID_QUEUE) -36 (PI_ERROR_INVALID_QUEUE)",
-    )
     def test_cat_multi_device(self, dtype=torch.float):
         x_cpu1 = torch.randn([1, 2, 28, 28], device=cpu_device)
         x_cpu2 = torch.randn([1, 2, 28, 28], device=cpu_device)

tests/gpu/examples/test_conv.py

Lines changed: 0 additions & 3 deletions
@@ -639,7 +639,6 @@ def test_group_conv3d_channels_last(self, dtype=torch.float):
         not torch.xpu.has_channels_last_1d() or torch.xpu.using_onednn_layout(),
         reason="doesn't enable channels last 1d or channels last does not support onednn block format",
     )
-    @pytest.mark.skip(reason="PT2.5: Tensor-likes are not close!")
     def test_channels_last_1d_fwd(self, dtype=torch.float):
         shapes = [
             (2, 2, 3),

@@ -708,7 +707,6 @@ def test_channels_last_1d_fwd(self, dtype=torch.float):
         not torch.xpu.has_channels_last_1d() or torch.xpu.using_onednn_layout(),
         reason="doesn't enable channels last 1d or channels last does not support onednn block format",
     )
-    @pytest.mark.skip(reason="PT2.5: Tensor-likes are not close!")
     def test_channels_last_1d_bwd(self, dtype=torch.float):
         shapes = [
             (1, 7, 15000),

@@ -978,7 +976,6 @@ def test_conv2d_bia_bf16_input_bf16_bia(self, dtype=torch.float):
         not torch.xpu.has_channels_last_1d() or torch.xpu.using_onednn_layout(),
         reason="doesn't enable channels last 1d or channels last does not support onednn block format",
     )
-    @pytest.mark.skip(reason="PT2.5: Tensor-likes are not close!")
     def test_channels_last_1d_bwd_no_grad(self, dtype=torch.float):
         shapes = [
             (1, 7, 15000),

tests/gpu/examples/test_embedding_bag.py

Lines changed: 0 additions & 3 deletions
@@ -61,9 +61,6 @@ def test_embedding_bag_all(self, dtype=torch.float32):
             rtol=1e-5,
         )
 
-    @pytest.mark.skip(
-        reason="PT2.5: Assertion `vec_idx < num_row` failed",
-    )
     def test_embeddingbag_out_of_bounds(self):
         stderr = TestCase.runWithPytorchAPIUsageStderr(
             f"""\

tests/gpu/examples/test_fp8_index_select.py

Lines changed: 0 additions & 5 deletions
@@ -5,13 +5,8 @@
     cast_to_fp8,
 )
 
-import pytest
-
 
 class TestTorchMethod(TestCase):
-    @pytest.mark.skip(
-        reason="PT2.5: 'index_select_xpu' not implemented for 'Float8_e4m3fn'"
-    )
     def test_index_select(self, dtype=torch.float):
         dim_size = 10
         dims = 3
