Skip to content

Commit 3a9419c

Browse files
chunhuanMengCopilottoyxu
authored
Refactor Triangular Operation Kernels (triu/tril) for XPU in SYCL (#1735)
Fix pytorch/pytorch#155651

Refactor the triangular operation kernels (`triu` and `tril`) for XPU in SYCL. Key changes include processing several elements per thread, improving memory access patterns, and simplifying the handling of tensor metadata. Additionally, redundant code has been removed for a cleaner implementation.

### Refactoring:

* **Refactored `ApplyTriuTrilKernelFunctor`**:
  - Added support for processing multiple elements per thread (`elements_per_thread`) and grouped the memory accesses by phase (load, load, load, load; compute, compute, compute, compute; store, store, store, store — instead of interleaving load/compute/store per element) to improve memory throughput.
  - Refactored offset computation to handle multi-dimensional tensors more effectively using `TensorInfo`.
  - Introduced the `BOOL_SWITCH` macro to select, at compile time, the code path for in-place versus out-of-place operation.
* **Improved handling of tensor metadata**:
  - Replaced manual stride and size calculations with `TensorInfo` for cleaner and more maintainable code.

### Code Cleanup:

* **Removed redundant return statements**:
  - Eliminated commented-out `return result;` lines in the `tril_kernel` and `triu_kernel` functions for clarity.

---------

Co-authored-by: Copilot <[email protected]>
Co-authored-by: Yutao Xu <[email protected]>
1 parent ea48cc4 commit 3a9419c

File tree

2 files changed

+148
-91
lines changed

2 files changed

+148
-91
lines changed

src/ATen/native/xpu/sycl/TriangularOpsKernels.cpp

Lines changed: 135 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,162 @@
66
#include <ATen/native/CanUse32BitIndexMath.h>
77
#include <ATen/native/Resize.h>
88
#include <comm/SYCLContext.h>
9+
#include <comm/TensorInfo.h>
910

1011
#include <ATen/native/xpu/sycl/TriangularOpsKernels.h>
1112

13+
// BOOL_SWITCH(COND, CONST_NAME, callable)
//
// Turns the run-time boolean COND into a compile-time constant.
// The macro expands to an immediately-invoked lambda (capturing by reference)
// that tests COND and, in each branch, declares CONST_NAME as a
// `constexpr static bool` (true in the first branch, false in the second)
// before invoking the callable passed as the trailing argument(s).
// Because CONST_NAME is constexpr inside the callable's expansion, the
// callable may use it where a constant expression is required, e.g. as a
// template argument. The expression yields whatever the callable returns.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
14+
[&] { \
15+
if (COND) { \
16+
constexpr static bool CONST_NAME = true; \
17+
return __VA_ARGS__(); \
18+
} else { \
19+
constexpr static bool CONST_NAME = false; \
20+
return __VA_ARGS__(); \
21+
} \
22+
}()
23+
1224
namespace at::native::xpu {
1325

1426
using namespace at::xpu;
1527

16-
template <typename scalar_t, typename IndexType, bool upper>
28+
// SYCL kernel functor applying a triangular mask over the last two
// dimensions of a tensor.
//
// Template parameters:
//   scalar_t            - element type of both tensors.
//   IndexType           - integer type used for index and offset arithmetic.
//   upper               - true: keep elements with col - row >= k;
//                         false: keep elements with col - row <= k.
//   elements_per_thread - number of consecutive last-dimension positions
//                         handled by one work-item.
//   inplace             - true: only write zeros to masked-out elements of
//                         result; false: read self, mask, write result.
//
// The linear index space is "padded": the last dimension is treated as
// having last_dim_padded_ entries, so every work-item's chunk starts at a
// multiple of elements_per_thread within a row; positions at or beyond the
// real last-dimension size are skipped by the `col + i < sizes[dims - 1]`
// loop conditions.
template <
29+
typename scalar_t,
30+
typename IndexType,
31+
bool upper,
32+
int elements_per_thread,
33+
bool inplace>
1734
struct ApplyTriuTrilKernelFunctor {
1835
void operator()(sycl::nd_item<1> item) const {
19-
for (size_t linearIndex = item.get_global_id(0); linearIndex < (size_t)N;
20-
linearIndex += item.get_global_range()[0]) {
21-
IndexType batch_id = linearIndex / (self_size_0 * self_size_1);
22-
IndexType row = (linearIndex % (self_size_0 * self_size_1)) / self_size_1;
23-
IndexType col = (linearIndex % (self_size_0 * self_size_1)) % self_size_1;
24-
25-
IndexType src_index =
26-
batch_id * self_stride + row * self_stride_0 + col * self_stride_1;
27-
IndexType tgt_index = batch_id * result_stride + row * result_stride_0 +
28-
col * result_stride_1;
29-
30-
bool mask = upper ? (col - row >= k) : (col - row <= k);
31-
result_ptr[tgt_index] = mask ? self_ptr[src_index] : scalar_t(0);
36+
// First padded linear position owned by this work-item.
// NOTE(review): linear_idx is IndexType while N_padded_ is int64_t; confirm
// the caller never selects a 32-bit IndexType when the padded element count
// (or the rounded-up global range times elements_per_thread) exceeds it.
IndexType linear_idx = item.get_global_id(0) * elements_per_thread;
37+
// Work-items beyond the padded element count have nothing to do.
if (linear_idx >= N_padded_) {
38+
return;
39+
}
40+
auto dims = self_info_.dims;
41+
42+
// Compute column index and row index
43+
IndexType col = linear_idx % last_dim_padded_;
44+
linear_idx /= last_dim_padded_;
45+
IndexType row = linear_idx % self_info_.sizes[dims - 2];
46+
47+
// In-place fast path: when every element of this chunk is kept by the mask
// there is nothing to overwrite. The lower-triangle test uses
// col + elements_per_thread (one past the chunk's last column), so it is
// conservative: it may decline to skip a chunk that is in fact fully kept.
if constexpr (inplace) {
48+
bool mask_all_true =
49+
upper ? (col - row >= k_) : (col + elements_per_thread - row <= k_);
50+
if (mask_all_true)
51+
return;
52+
}
53+
54+
// Compute offset
55+
IndexType self_offset = 0, result_offset = 0;
56+
self_offset += self_info_.strides[dims - 1] * col;
57+
result_offset += result_info_.strides[dims - 1] * col;
58+
linear_idx /= self_info_.sizes[dims - 2];
59+
self_offset += self_info_.strides[dims - 2] * row;
60+
result_offset += result_info_.strides[dims - 2] * row;
61+
62+
// Compute remaining offsets
63+
IndexType running_index;
64+
// Peel the leading (batch) dimensions off linear_idx from innermost to
// outermost, using self's sizes for both tensors' offsets.
for (int i = dims - 3; i >= 0; --i) {
65+
running_index = linear_idx % self_info_.sizes[i];
66+
linear_idx /= self_info_.sizes[i];
67+
self_offset += running_index * self_info_.strides[i];
68+
result_offset += running_index * result_info_.strides[i];
69+
}
70+
71+
// In-place: write zero only to masked-out elements that lie inside the
// real (unpadded) last dimension; kept elements are left untouched.
if constexpr (inplace) {
72+
#pragma unroll
73+
for (int i = 0;
74+
i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
75+
i++) {
76+
bool mask = upper ? (col + i - row >= k_) : (col + i - row <= k_);
77+
if (!mask)
78+
result_info_
79+
.data[result_offset + i * result_info_.strides[dims - 1]] =
80+
scalar_t(0);
81+
}
82+
} else {
83+
// Out-of-place: stage the chunk in a value-initialized local array, in
// three separate passes (load, mask, store).
scalar_t frag[elements_per_thread] = {};
84+
// has_mask is false only when no element of the chunk can be kept; the
// loads are then skipped and the value-initialized frag is stored as is.
bool has_mask = (upper && col + elements_per_thread - row >= k_) ||
85+
(!upper && col - row <= k_);
86+
if (has_mask) {
87+
#pragma unroll
88+
for (int i = 0;
89+
i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
90+
i++)
91+
frag[i] =
92+
self_info_.data[self_offset + i * self_info_.strides[dims - 1]];
93+
94+
#pragma unroll
95+
for (int i = 0; i < elements_per_thread; i++) {
96+
bool mask = upper ? (col + i - row >= k_) : (col + i - row <= k_);
97+
frag[i] = mask ? frag[i] : scalar_t(0);
98+
}
99+
}
100+
101+
// Store the chunk, again skipping positions in the padding.
#pragma unroll
102+
for (int i = 0;
103+
i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
104+
i++)
105+
result_info_.data[result_offset + i * result_info_.strides[dims - 1]] =
106+
frag[i];
32107
}
33108
}
34109
// Stores the two tensor descriptors, the diagonal offset k, the padded
// element count and the padded last-dimension length read by operator().
ApplyTriuTrilKernelFunctor(
35-
const int64_t k_,
36-
int64_t N_,
37-
IndexType self_size_0_,
38-
IndexType self_size_1_,
39-
IndexType self_stride_,
40-
IndexType self_stride_0_,
41-
IndexType self_stride_1_,
42-
IndexType result_stride_,
43-
IndexType result_stride_0_,
44-
IndexType result_stride_1_,
45-
scalar_t* result_ptr_,
46-
const scalar_t* self_ptr_)
47-
: k(k_),
48-
N(N_),
49-
self_size_0(self_size_0_),
50-
self_size_1(self_size_1_),
51-
self_stride(self_stride_),
52-
self_stride_0(self_stride_0_),
53-
self_stride_1(self_stride_1_),
54-
result_stride(result_stride_),
55-
result_stride_0(result_stride_0_),
56-
result_stride_1(result_stride_1_),
57-
result_ptr(result_ptr_),
58-
self_ptr(self_ptr_) {}
110+
at::xpu::detail::TensorInfo<scalar_t, IndexType> result_info,
111+
at::xpu::detail::TensorInfo<const scalar_t, IndexType> self_info,
112+
const int64_t k,
113+
const int64_t N_padded,
114+
const IndexType last_dim_padded)
115+
: result_info_(result_info),
116+
self_info_(self_info),
117+
k_(k),
118+
N_padded_(N_padded),
119+
last_dim_padded_(last_dim_padded) {}
59120

60121
private:
61-
const int64_t k;
62-
int64_t N;
63-
IndexType self_size_0;
64-
IndexType self_size_1;
65-
IndexType self_stride;
66-
IndexType self_stride_0;
67-
IndexType self_stride_1;
68-
IndexType result_stride;
69-
IndexType result_stride_0;
70-
IndexType result_stride_1;
71-
scalar_t* result_ptr;
72-
const scalar_t* self_ptr;
122+
// Destination tensor descriptor (data pointer, sizes, strides, dims).
at::xpu::detail::TensorInfo<scalar_t, IndexType> result_info_;
123+
// Source tensor descriptor; read only in the out-of-place path, but its
// sizes and dims drive the index decomposition in both paths.
at::xpu::detail::TensorInfo<const scalar_t, IndexType> self_info_;
124+
// Diagonal offset compared against col - row.
const int64_t k_;
125+
// Total number of positions in the padded index space.
const int64_t N_padded_;
126+
// Last-dimension length used for the padded index space.
const IndexType last_dim_padded_;
73127
};
74128

75129
// Host-side launcher for ApplyTriuTrilKernelFunctor.
//
//   result - destination tensor.
//   self   - source tensor; when `self.is_same(result)` the in-place kernel
//            variant is instantiated.
//   k      - diagonal offset forwarded to the kernel.
//
// `upper` selects the kept triangle (see the functor); `IndexType` is the
// integer type used for device-side index arithmetic.
template <typename scalar_t, typename IndexType, bool upper>
76130
void apply_triu_tril(
77131
const Tensor& result,
78132
const Tensor& self,
79133
const int64_t k) {
80-
auto N = self.numel();
81-
IndexType self_size_0 = (IndexType)self.size(-2);
82-
IndexType self_size_1 = (IndexType)self.size(-1);
83-
IndexType self_stride = (IndexType)(self.dim() > 2 ? self.stride(-3) : 1);
84-
IndexType self_stride_0 = (IndexType)self.stride(-2);
85-
IndexType self_stride_1 = (IndexType)self.stride(-1);
86-
IndexType result_stride =
87-
(IndexType)(result.dim() > 2 ? result.stride(-3) : 1);
88-
IndexType result_stride_0 = (IndexType)result.stride(-2);
89-
IndexType result_stride_1 = (IndexType)result.stride(-1);
90-
91-
scalar_t* result_ptr = result.data_ptr<scalar_t>();
92-
const scalar_t* self_ptr = self.const_data_ptr<scalar_t>();
93-
94-
ApplyTriuTrilKernelFunctor<scalar_t, IndexType, upper> kfn(
95-
k,
96-
N,
97-
self_size_0,
98-
self_size_1,
99-
self_stride,
100-
self_stride_0,
101-
self_stride_1,
102-
result_stride,
103-
result_stride_0,
104-
result_stride_1,
105-
result_ptr,
106-
self_ptr);
107-
108-
int64_t group_size = syclMaxWorkGroupSize(kfn);
109-
auto num_groups = ceil_div(N, group_size);
110-
auto total_items = num_groups * group_size;
111-
auto& queue = getCurrentSYCLQueue();
112-
113-
sycl_kernel_submit(
114-
sycl::range<1>(total_items), sycl::range<1>(group_size), queue, kfn);
134+
// Each work-item covers 8 bytes' worth of elements for types narrower than
// 8 bytes, and a single element otherwise.
constexpr int elements_per_thread =
135+
sizeof(scalar_t) < 8 ? 8 / sizeof(scalar_t) : 1;
136+
auto sizes = self.sizes();
137+
// Round the last dimension up to a multiple of elements_per_thread so that
// no work-item's chunk straddles two rows.
int64_t last_dim_padded =
138+
round_up<int64_t>(sizes.back(), elements_per_thread);
139+
// Product of all sizes except the last, times the padded last dimension.
int64_t N_padded =
140+
c10::multiply_integers(sizes.begin(), sizes.end() - 1) * last_dim_padded;
141+
142+
int64_t local_range = syclMaxWorkItemsPerSubSlice();
143+
// Number of work-items needed (N_padded / elements_per_thread), rounded up
// to a whole number of work-groups of size local_range.
int64_t global_range =
144+
((N_padded / elements_per_thread + local_range - 1) / local_range) *
145+
local_range;
146+
147+
auto result_info =
148+
at::xpu::detail::getTensorInfo<scalar_t, IndexType>(result);
149+
auto self_info =
150+
at::xpu::detail::getTensorInfo<const scalar_t, IndexType>(self);
151+
// Lift the run-time "same tensor" check into the `inplace` template
// argument, then build and submit the matching kernel instantiation.
BOOL_SWITCH(self.is_same(result), inplace, [&] {
152+
ApplyTriuTrilKernelFunctor<
153+
scalar_t,
154+
IndexType,
155+
upper,
156+
elements_per_thread,
157+
inplace>
158+
kfn(result_info, self_info, k, N_padded, last_dim_padded);
159+
sycl_kernel_submit(
160+
sycl::range<1>(global_range),
161+
sycl::range<1>(local_range),
162+
getCurrentSYCLQueue(),
163+
kfn);
164+
});
115165
}
116166

117167
#define TRIU_TRIL_LAMBDA(upper) \
@@ -128,7 +178,6 @@ void tril_kernel(const Tensor& result, const Tensor& self, int64_t k) {
128178
result.resize_as_(self);
129179
}
130180
if (self.numel() == 0) {
131-
// return result;
132181
return;
133182
}
134183

@@ -140,16 +189,13 @@ void tril_kernel(const Tensor& result, const Tensor& self, int64_t k) {
140189
self.scalar_type(),
141190
"tril_xpu",
142191
TRIU_TRIL_LAMBDA(false));
143-
144-
// return result;
145192
}
146193

147194
void triu_kernel(const Tensor& result, const Tensor& self, int64_t k) {
148195
if (result.sizes() != self.sizes()) {
149196
result.resize_as_(self);
150197
}
151198
if (self.numel() == 0) {
152-
// return result;
153199
return;
154200
}
155201
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
@@ -160,8 +206,6 @@ void triu_kernel(const Tensor& result, const Tensor& self, int64_t k) {
160206
self.scalar_type(),
161207
"triu_xpu",
162208
TRIU_TRIL_LAMBDA(true));
163-
164-
// return result;
165209
}
166210

167211
} // namespace at::native::xpu

test/regressions/test_tril.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Owner(s): ["module: intel"]
2+
import torch
3+
from torch.testing._internal.common_utils import TestCase
4+
5+
6+
class TestSimpleBinary(TestCase):
    """Regression test for ``torch.tril`` on a very large 2-D XPU tensor."""

    def test_tril(self, dtype=torch.bool):
        """Run ``torch.tril`` on a 131072 x 131072 tensor and spot-check it.

        The tensor has 2**34 elements, so linear element offsets exceed the
        range of a 32-bit integer. The full result is too large to compare
        against a reference cheaply, so individual elements on both sides of
        the main diagonal are checked instead, including rows whose linear
        offset is beyond 2**31.

        Args:
            dtype: element type of the input tensor (defaults to ``torch.bool``).
        """
        max_seq_length = 131072
        with torch.device("xpu"):
            a = torch.ones(max_seq_length, max_seq_length, dtype=dtype)
            causal_mask = torch.tril(a)
            torch.xpu.synchronize()

        self.assertEqual(
            tuple(causal_mask.shape), (max_seq_length, max_seq_length)
        )

        last = max_seq_length - 1
        mid = max_seq_length // 2

        # On or below the main diagonal (col <= row): the ones must survive.
        kept = [(0, 0), (mid, 0), (mid, mid), (last, 0), (last, mid), (last, last)]
        for row, col in kept:
            self.assertTrue(
                bool(causal_mask[row, col].item()),
                f"expected element ({row}, {col}) to be kept by tril",
            )

        # Strictly above the main diagonal (col > row): must be zeroed.
        zeroed = [(0, 1), (0, last), (mid, mid + 1), (mid, last), (last - 1, last)]
        for row, col in zeroed:
            self.assertFalse(
                bool(causal_mask[row, col].item()),
                f"expected element ({row}, {col}) to be zeroed by tril",
            )

0 commit comments

Comments
 (0)