Skip to content

Commit 766a896

Browse files
authored
override sort operators for UT functionality (#4956) (#4967)
* fix example case test_mode in test_torch_mode.py * fix comments * fix ut
1 parent 161b58a commit 766a896

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

csrc/gpu/aten/operators/Sort.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include <core/Memory.h>
22
#include <core/detail/IndexUtils.h>
33
#include <core/detail/TensorInfo.h>
4+
#ifdef USE_OVERRIDE_OP
5+
#include "utils/CustomOperatorRegistration.h"
6+
#endif
47
#include <utils/DPCPP.h>
58
#include "comm/ATDispatch.h"
69
#include "comm/Numerics.h"
@@ -55,6 +58,56 @@ Tensor argsort(const Tensor& self, bool stable, int64_t dim, bool descending) {
5558
return std::get<1>(
5659
sort_out_stable(self, stable, dim, descending, sorted, indices));
5760
}
61+
#ifdef USE_OVERRIDE_OP
62+
std::tuple<at::Tensor, at::Tensor> sort_ipex(
63+
const at::Tensor& self,
64+
int64_t dim,
65+
bool descending) {
66+
return at::AtenIpexTypeXPU::sort(self, dim, descending);
67+
}
68+
69+
std::tuple<Tensor&, Tensor&> sort_values(
70+
const Tensor& input,
71+
int64_t dim,
72+
bool order,
73+
Tensor& sorted,
74+
Tensor& indices) {
75+
return at::AtenIpexTypeXPU::sort_out(input, dim, order, sorted, indices);
76+
}
77+
78+
std::tuple<Tensor, Tensor> sort_stable(
79+
const Tensor& self,
80+
c10::optional<bool> stable,
81+
int64_t dim,
82+
bool descending) {
83+
return at::AtenIpexTypeXPU::sort(self, stable, dim, descending);
84+
}
5885

86+
std::tuple<Tensor&, Tensor&> sort_values_stable(
87+
const Tensor& self,
88+
c10::optional<bool> stable,
89+
int64_t dim,
90+
bool descending,
91+
Tensor& values,
92+
Tensor& indices) {
93+
return at::AtenIpexTypeXPU::sort_out(
94+
self, stable, dim, descending, values, indices);
95+
}
96+
#endif
5997
} // namespace AtenIpexTypeXPU
6098
} // namespace at
99+
100+
#ifdef USE_OVERRIDE_OP
101+
namespace {
102+
103+
IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
104+
m.impl("sort", TORCH_FN((&at::AtenIpexTypeXPU::sort_ipex)));
105+
m.impl("sort.stable", TORCH_FN((&at::AtenIpexTypeXPU::sort_stable)));
106+
m.impl("sort.values", TORCH_FN((&at::AtenIpexTypeXPU::sort_values)));
107+
m.impl(
108+
"sort.values_stable",
109+
TORCH_FN((&at::AtenIpexTypeXPU::sort_values_stable)));
110+
}
111+
112+
} // namespace
113+
#endif

scripts/tools/torchgen/yaml/xpu_functions.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ supported:
1818
# - col2im.out
1919
# - im2col
2020
# - im2col.out
21+
# - sort
22+
# - sort.stable
23+
# - sort.values
24+
# - sort.values_stable
2125
################## override below ops due to performance issues
2226
# - convolution_overrideable
2327
# - convolution_backward_overrideable
@@ -585,10 +589,6 @@ supported:
585589
# - softplus_backward.grad_input
586590
# - softshrink.out
587591
# - softshrink_backward.grad_input
588-
- sort
589-
# - sort.stable
590-
- sort.values
591-
# - sort.values_stable
592592
- special_entr.out
593593
- special_erfcx.out
594594
- special_i0e.out

tests/gpu/examples/test_torch_mode.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
from torch.testing._internal.common_utils import TestCase
33
import intel_extension_for_pytorch # noqa
44

5-
import pytest
6-
75
cpu_device = torch.device("cpu")
86
xpu_device = torch.device("xpu")
97
value_range = 30
108

119

1210
class TestTorchMethod(TestCase):
13-
@pytest.mark.skip(reason="PT2.5: Scalars are not equal!")
1411
def test_mode(self):
1512
def mode_test_helper(input_list):
1613
for input_cpu in input_list:

0 commit comments

Comments
 (0)