|
1 | 1 | #include <core/Memory.h> |
2 | 2 | #include <core/detail/IndexUtils.h> |
3 | 3 | #include <core/detail/TensorInfo.h> |
| 4 | +#ifdef USE_OVERRIDE_OP |
| 5 | +#include "utils/CustomOperatorRegistration.h" |
| 6 | +#endif |
4 | 7 | #include <utils/DPCPP.h> |
5 | 8 | #include "comm/ATDispatch.h" |
6 | 9 | #include "comm/Numerics.h" |
@@ -55,6 +58,56 @@ Tensor argsort(const Tensor& self, bool stable, int64_t dim, bool descending) { |
55 | 58 | return std::get<1>( |
56 | 59 | sort_out_stable(self, stable, dim, descending, sorted, indices)); |
57 | 60 | } |
| 61 | +#ifdef USE_OVERRIDE_OP |
| 62 | +std::tuple<at::Tensor, at::Tensor> sort_ipex( |
| 63 | + const at::Tensor& self, |
| 64 | + int64_t dim, |
| 65 | + bool descending) { |
| 66 | + return at::AtenIpexTypeXPU::sort(self, dim, descending); |
| 67 | +} |
| 68 | + |
| 69 | +std::tuple<Tensor&, Tensor&> sort_values( |
| 70 | + const Tensor& input, |
| 71 | + int64_t dim, |
| 72 | + bool order, |
| 73 | + Tensor& sorted, |
| 74 | + Tensor& indices) { |
| 75 | + return at::AtenIpexTypeXPU::sort_out(input, dim, order, sorted, indices); |
| 76 | +} |
| 77 | + |
| 78 | +std::tuple<Tensor, Tensor> sort_stable( |
| 79 | + const Tensor& self, |
| 80 | + c10::optional<bool> stable, |
| 81 | + int64_t dim, |
| 82 | + bool descending) { |
| 83 | + return at::AtenIpexTypeXPU::sort(self, stable, dim, descending); |
| 84 | +} |
58 | 85 |
|
| 86 | +std::tuple<Tensor&, Tensor&> sort_values_stable( |
| 87 | + const Tensor& self, |
| 88 | + c10::optional<bool> stable, |
| 89 | + int64_t dim, |
| 90 | + bool descending, |
| 91 | + Tensor& values, |
| 92 | + Tensor& indices) { |
| 93 | + return at::AtenIpexTypeXPU::sort_out( |
| 94 | + self, stable, dim, descending, values, indices); |
| 95 | +} |
| 96 | +#endif |
59 | 97 | } // namespace AtenIpexTypeXPU |
60 | 98 | } // namespace at |
| 99 | + |
| 100 | +#ifdef USE_OVERRIDE_OP |
| 101 | +namespace { |
| 102 | + |
| 103 | +IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) { |
| 104 | + m.impl("sort", TORCH_FN((&at::AtenIpexTypeXPU::sort_ipex))); |
| 105 | + m.impl("sort.stable", TORCH_FN((&at::AtenIpexTypeXPU::sort_stable))); |
| 106 | + m.impl("sort.values", TORCH_FN((&at::AtenIpexTypeXPU::sort_values))); |
| 107 | + m.impl( |
| 108 | + "sort.values_stable", |
| 109 | + TORCH_FN((&at::AtenIpexTypeXPU::sort_values_stable))); |
| 110 | +} |
| 111 | + |
| 112 | +} // namespace |
| 113 | +#endif |
0 commit comments