#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/Resize.h>
#include <comm/SYCLContext.h>
+ #include <comm/TensorInfo.h>

#include <ATen/native/xpu/sycl/TriangularOpsKernels.h>

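+ // BOOL_SWITCH maps a runtime bool onto a compile-time constant by
+ // instantiating the callable for both values, so the kernel below can
+ // branch with `if constexpr`, e.g. BOOL_SWITCH(cond, kCond, [&] { f<kCond>(); });
+ // (the name `f` here is only illustrative).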
+ #define BOOL_SWITCH(COND, CONST_NAME, ...)        \
+   [&] {                                           \
+     if (COND) {                                   \
+       constexpr static bool CONST_NAME = true;    \
+       return __VA_ARGS__();                       \
+     } else {                                      \
+       constexpr static bool CONST_NAME = false;   \
+       return __VA_ARGS__();                       \
+     }                                             \
+   }()
+
namespace at::native::xpu {

using namespace at::xpu;

- template <typename scalar_t, typename IndexType, bool upper>
+ template <
+     typename scalar_t,
+     typename IndexType,
+     bool upper,
+     int elements_per_thread,
+     bool inplace>
struct ApplyTriuTrilKernelFunctor {
  void operator()(sycl::nd_item<1> item) const {
-    for (size_t linearIndex = item.get_global_id(0); linearIndex < (size_t)N;
-         linearIndex += item.get_global_range()[0]) {
-      IndexType batch_id = linearIndex / (self_size_0 * self_size_1);
-      IndexType row = (linearIndex % (self_size_0 * self_size_1)) / self_size_1;
-      IndexType col = (linearIndex % (self_size_0 * self_size_1)) % self_size_1;
-
-      IndexType src_index =
-          batch_id * self_stride + row * self_stride_0 + col * self_stride_1;
-      IndexType tgt_index = batch_id * result_stride + row * result_stride_0 +
-          col * result_stride_1;
-
-      bool mask = upper ? (col - row >= k) : (col - row <= k);
-      result_ptr[tgt_index] = mask ? self_ptr[src_index] : scalar_t(0);
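+    // Each work-item handles `elements_per_thread` consecutive elements of
+    // the (padded) last dimension; col/row are decoded from the padded
+    // linear index.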
+    IndexType linear_idx = item.get_global_id(0) * elements_per_thread;
+    if (linear_idx >= N_padded_) {
+      return;
+    }
+    auto dims = self_info_.dims;
+
+    // Compute column index and row index
+    IndexType col = linear_idx % last_dim_padded_;
+    linear_idx /= last_dim_padded_;
+    IndexType row = linear_idx % self_info_.sizes[dims - 2];
+
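+    // In-place: if every element this work-item covers is kept by the mask,
+    // there is nothing to zero, so bail out early.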
+    if constexpr (inplace) {
+      bool mask_all_true =
+          upper ? (col - row >= k_) : (col + elements_per_thread - row <= k_);
+      if (mask_all_true)
+        return;
+    }
+
+    // Compute offset
+    IndexType self_offset = 0, result_offset = 0;
+    self_offset += self_info_.strides[dims - 1] * col;
+    result_offset += result_info_.strides[dims - 1] * col;
+    linear_idx /= self_info_.sizes[dims - 2];
+    self_offset += self_info_.strides[dims - 2] * row;
+    result_offset += result_info_.strides[dims - 2] * row;
+
+    // Compute remaining offsets
+    IndexType running_index;
+    for (int i = dims - 3; i >= 0; --i) {
+      running_index = linear_idx % self_info_.sizes[i];
+      linear_idx /= self_info_.sizes[i];
+      self_offset += running_index * self_info_.strides[i];
+      result_offset += running_index * result_info_.strides[i];
+    }
+
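+    // Write-back: in place, only the masked-off elements are touched; out of
+    // place, a register fragment is staged, masked, and stored as one chunk.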
+    if constexpr (inplace) {
+#pragma unroll
+      for (int i = 0;
+           i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
+           i++) {
+        bool mask = upper ? (col + i - row >= k_) : (col + i - row <= k_);
+        if (!mask)
+          result_info_
+              .data[result_offset + i * result_info_.strides[dims - 1]] =
+              scalar_t(0);
+      }
+    } else {
+      scalar_t frag[elements_per_thread] = {};
+      bool has_mask = (upper && col + elements_per_thread - row >= k_) ||
+          (!upper && col - row <= k_);
+      if (has_mask) {
+#pragma unroll
+        for (int i = 0;
+             i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
+             i++)
+          frag[i] =
+              self_info_.data[self_offset + i * self_info_.strides[dims - 1]];
+
+#pragma unroll
+        for (int i = 0; i < elements_per_thread; i++) {
+          bool mask = upper ? (col + i - row >= k_) : (col + i - row <= k_);
+          frag[i] = mask ? frag[i] : scalar_t(0);
+        }
+      }
+
+#pragma unroll
+      for (int i = 0;
+           i < elements_per_thread && col + i < self_info_.sizes[dims - 1];
+           i++)
+        result_info_.data[result_offset + i * result_info_.strides[dims - 1]] =
+            frag[i];
    }
  }
  ApplyTriuTrilKernelFunctor(
-      const int64_t k_,
-      int64_t N_,
-      IndexType self_size_0_,
-      IndexType self_size_1_,
-      IndexType self_stride_,
-      IndexType self_stride_0_,
-      IndexType self_stride_1_,
-      IndexType result_stride_,
-      IndexType result_stride_0_,
-      IndexType result_stride_1_,
-      scalar_t* result_ptr_,
-      const scalar_t* self_ptr_)
-      : k(k_),
-        N(N_),
-        self_size_0(self_size_0_),
-        self_size_1(self_size_1_),
-        self_stride(self_stride_),
-        self_stride_0(self_stride_0_),
-        self_stride_1(self_stride_1_),
-        result_stride(result_stride_),
-        result_stride_0(result_stride_0_),
-        result_stride_1(result_stride_1_),
-        result_ptr(result_ptr_),
-        self_ptr(self_ptr_) {}
+      at::xpu::detail::TensorInfo<scalar_t, IndexType> result_info,
+      at::xpu::detail::TensorInfo<const scalar_t, IndexType> self_info,
+      const int64_t k,
+      const int64_t N_padded,
+      const IndexType last_dim_padded)
+      : result_info_(result_info),
+        self_info_(self_info),
+        k_(k),
+        N_padded_(N_padded),
+        last_dim_padded_(last_dim_padded) {}

 private:
-  const int64_t k;
-  int64_t N;
-  IndexType self_size_0;
-  IndexType self_size_1;
-  IndexType self_stride;
-  IndexType self_stride_0;
-  IndexType self_stride_1;
-  IndexType result_stride;
-  IndexType result_stride_0;
-  IndexType result_stride_1;
-  scalar_t* result_ptr;
-  const scalar_t* self_ptr;
+  at::xpu::detail::TensorInfo<scalar_t, IndexType> result_info_;
+  at::xpu::detail::TensorInfo<const scalar_t, IndexType> self_info_;
+  const int64_t k_;
+  const int64_t N_padded_;
+  const IndexType last_dim_padded_;
};

template <typename scalar_t, typename IndexType, bool upper>
void apply_triu_tril(
    const Tensor& result,
    const Tensor& self,
    const int64_t k) {
-  auto N = self.numel();
-  IndexType self_size_0 = (IndexType)self.size(-2);
-  IndexType self_size_1 = (IndexType)self.size(-1);
-  IndexType self_stride = (IndexType)(self.dim() > 2 ? self.stride(-3) : 1);
-  IndexType self_stride_0 = (IndexType)self.stride(-2);
-  IndexType self_stride_1 = (IndexType)self.stride(-1);
-  IndexType result_stride =
-      (IndexType)(result.dim() > 2 ? result.stride(-3) : 1);
-  IndexType result_stride_0 = (IndexType)result.stride(-2);
-  IndexType result_stride_1 = (IndexType)result.stride(-1);
-
-  scalar_t* result_ptr = result.data_ptr<scalar_t>();
-  const scalar_t* self_ptr = self.const_data_ptr<scalar_t>();
-
-  ApplyTriuTrilKernelFunctor<scalar_t, IndexType, upper> kfn(
-      k,
-      N,
-      self_size_0,
-      self_size_1,
-      self_stride,
-      self_stride_0,
-      self_stride_1,
-      result_stride,
-      result_stride_0,
-      result_stride_1,
-      result_ptr,
-      self_ptr);
-
-  int64_t group_size = syclMaxWorkGroupSize(kfn);
-  auto num_groups = ceil_div(N, group_size);
-  auto total_items = num_groups * group_size;
-  auto& queue = getCurrentSYCLQueue();
-
-  sycl_kernel_submit(
-      sycl::range<1>(total_items), sycl::range<1>(group_size), queue, kfn);
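+  // Chunk size targets ~8-byte accesses per work-item (4 halfs, 2 floats,
+  // 1 double); the last dim is padded up so every chunk decodes cleanly.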
+  constexpr int elements_per_thread =
+      sizeof(scalar_t) < 8 ? 8 / sizeof(scalar_t) : 1;
+  auto sizes = self.sizes();
+  int64_t last_dim_padded =
+      round_up<int64_t>(sizes.back(), elements_per_thread);
+  int64_t N_padded =
+      c10::multiply_integers(sizes.begin(), sizes.end() - 1) * last_dim_padded;
+
+  int64_t local_range = syclMaxWorkItemsPerSubSlice();
+  int64_t global_range =
+      ((N_padded / elements_per_thread + local_range - 1) / local_range) *
+      local_range;
+
+  auto result_info =
+      at::xpu::detail::getTensorInfo<scalar_t, IndexType>(result);
+  auto self_info =
+      at::xpu::detail::getTensorInfo<const scalar_t, IndexType>(self);
+  BOOL_SWITCH(self.is_same(result), inplace, [&] {
+    ApplyTriuTrilKernelFunctor<
+        scalar_t,
+        IndexType,
+        upper,
+        elements_per_thread,
+        inplace>
+        kfn(result_info, self_info, k, N_padded, last_dim_padded);
+    sycl_kernel_submit(
+        sycl::range<1>(global_range),
+        sycl::range<1>(local_range),
+        getCurrentSYCLQueue(),
+        kfn);
+  });
}

#define TRIU_TRIL_LAMBDA(upper) \
@@ -128,7 +178,6 @@ void tril_kernel(const Tensor& result, const Tensor& self, int64_t k) {
    result.resize_as_(self);
  }
  if (self.numel() == 0) {
-    // return result;
    return;
  }
@@ -140,16 +189,13 @@ void tril_kernel(const Tensor& result, const Tensor& self, int64_t k) {
      self.scalar_type(),
      "tril_xpu",
      TRIU_TRIL_LAMBDA(false));
-
-  // return result;
}

void triu_kernel(const Tensor& result, const Tensor& self, int64_t k) {
  if (result.sizes() != self.sizes()) {
    result.resize_as_(self);
  }
  if (self.numel() == 0) {
-    // return result;
    return;
  }
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
@@ -160,8 +206,6 @@ void triu_kernel(const Tensor& result, const Tensor& self, int64_t k) {
      self.scalar_type(),
      "triu_xpu",
      TRIU_TRIL_LAMBDA(true));
-
-  // return result;
}

} // namespace at::native::xpu