Commit b1582e1

hjhee and toyxu authored
Add aten::_native_multi_head_attention relevant operators (#892)
- _native_multi_head_attention
- _transform_bias_rescale_qkv
- _fused_sdp_choice

Co-authored-by: Yutao Xu <[email protected]>
1 parent 597d8e4 commit b1582e1
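
As context for reviewers, here is a minimal sketch (not part of this commit) of how these operators are typically exercised from Python: the fast path of nn.MultiheadAttention is backed by aten::_native_multi_head_attention, which relies on _transform_bias_rescale_qkv and _fused_sdp_choice. The sketch assumes a PyTorch build that includes torch-xpu-ops and an available "xpu" device; whether the fast path is taken still depends on the backend checks in _fused_sdp_choice.

# Sketch only -- assumes a PyTorch build with torch-xpu-ops and an "xpu" device.
import torch
import torch.nn as nn

device = "xpu"
embed_dim, num_heads, batch, seq_len = 256, 8, 4, 32

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).to(device).eval()
x = torch.randn(batch, seq_len, embed_dim, device=device)

with torch.inference_mode():
    # need_weights=False keeps the packed-projection / SDPA fast path eligible
    out, _ = mha(x, x, x, need_weights=False)

print(out.shape)  # torch.Size([4, 32, 256])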

File tree

14 files changed: +1057 -14 lines changed

.github/scripts/apply_torch_pr.py

Lines changed: 0 additions & 2 deletions

@@ -13,8 +13,6 @@
 "https://github.com/pytorch/pytorch/pull/126516",
 # Modify the tolerance level in TIMM benchmark
 "https://github.com/pytorch/pytorch/pull/129735",
-# [Intel GPU] Allow XPU device in cdist and pdist operators
-"https://github.com/pytorch/pytorch/pull/138441",
 ]
 )
 parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])

src/ATen/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -2,8 +2,8 @@
 
 file(GLOB xpu_h "xpu/*.h")
 file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp")
-file(GLOB xpu_sycl "native/xpu/sycl/*.cpp")
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/transformers/*.cpp")
+file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp")
 
 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
 list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp})
Lines changed: 334 additions & 0 deletions
@@ -0,0 +1,334 @@
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/split_native.h>
#endif

#include <ATen/native/transformers/SDPUtils.h>
#include <ATen/native/transformers/sycl/AttentionKernels.h>

#include <comm/SYCLContext.h>

namespace at {
namespace native {

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
// Note: currently only supports contiguous indexing, since nested tensors are
// all contiguous
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_xpu(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  // for nested tensors, B is the outermost size, but T is not regular; it
  // should be the largest size on dim1
  auto B = qkv.is_nested()
      ? native::get_nested_tensor_impl(qkv)->get_nested_sizes().size(0)
      : qkv.size(0);

  auto T = qkv.is_nested() ? native::NestedTensor_get_max_size(
                                 *native::get_nested_tensor_impl(qkv))[0]
                           : qkv.size(1);
  if (qkv.is_nested()) {
    // Don't mess with non-nested case for now since it's not set up to fiddle
    // with mask size.

    // Round T up to next multiple of 8 so as to be able to utilize Tensor
    // cores. Otherwise, sometimes with padding, *no* row will have the maximum
    // sequence length and so we'll have a non-divisible-by-8 dimension even if
    // the model author chose a multiple of 8.
    T = T + (8 - (T % 8)) % 8;
  }
  auto _3D = qkv_bias.size(0);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  const auto dim_per_head = D / num_head;

  // q_k_v B T 3D -> 3, B, num_head, T, dim_per_head
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options());

  xpu::_transform_bias_rescale_qkv_kernel(
      qkv, qkv_bias, num_head, q_k_v, B, T, D, dim_per_head);

  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}

static bool check_for_seq_len_1_nested_tensor(
    sdp::sdp_params params,
    bool debug) {
  // When this function is called we are assured that the nt is dim==4
  if (!params.query.is_nested()) {
    return true;
  }

  const auto nt_q_tensor_impl =
      at::native::get_nested_tensor_impl(params.query);
  const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes();
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t n_tensors = params.query.size(0);
  const int64_t size_tensor_stride = sizes.stride(0);

  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
  for (const auto i : c10::irange(n_tensors)) {
    if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) {
      if (debug) {
        TORCH_WARN(
            "Packed projection for fused kernels does not support sequence_length <= 1");
      }
      return false;
    }
  }

  return true;
}

int64_t _fused_sdp_choice_xpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  // We have implemented the efficient_attention backend with XeTLA; the
  // flash_attention backend is not supported yet and will be implemented in
  // the future. So we provide two backends here.
  sdp::sdp_params kernel_params{
      query, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
  // Because TORCH_CHECK checks whether the condition is true, we negate debug
  // so that the statements will be printed when debug is true
  bool print_debug = false;
  sdp::SDPBackend backend =
      sdp::can_use_mem_efficient_attention(kernel_params, print_debug)
      ? sdp::SDPBackend::efficient_attention
      : sdp::SDPBackend::math;
  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask,
    bool need_weights,
    bool average_attn_weights,
    const std::optional<int64_t> mask_type) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  TORCH_CHECK(
      !mask || !query.is_nested(),
      "NestedTensor with mask is not supported yet");
  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3, "expected 3-D `query`, got ", query.dim(), "-D tensor");
  TORCH_CHECK(
      query.is_nested() || query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3, "expected 3-D `key`, got ", key.dim(), "-D tensor");
  TORCH_CHECK(
      value.dim() == 3, "expected 3-D `value`, got ", value.dim(), "-D tensor");
  TORCH_CHECK(
      query.is_nested() || key.is_nested() || value.is_nested() ||
          (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");
  TORCH_CHECK(
      qkv_bias.dim() == 1,
      "expected 1-D `qkv_bias`, got ",
      qkv_bias.dim(),
      "-D tensor");
  TORCH_CHECK(
      qkv_bias.sizes()[0] == 3 * D,
      "expected `qkv_bias` first dim and first dim of query to be equal");
  TORCH_CHECK(
      D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? native::get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];

#endif
  const auto dim_per_head = D / num_head;
  if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 &&
      !need_weights) {
    // We have not done the linear projection yet, but the input for SDP is
    // expected to be 4-dimensional. We "cheaply" create view tensors that
    // will then be used for checking hot path conditions with
    // select_sd_backend
    auto q =
        query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
    auto k =
        key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
    auto v =
        value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);

    sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false};
    auto backend = static_cast<sdp::SDPBackend>(
        _fused_sdp_choice_xpu(q, k, v, mask, 0.0, false, {}, false));

    // strides from the packed projection for nested tensors when seq_len is 1
    // will trigger a contiguous call in the kernel, so we prevent this
    bool no_seq_len_1_nested = query.is_nested()
        ? check_for_seq_len_1_nested_tensor(kernel_params, false)
        : true;
    // The API for transformer_encoder is a mask of shape (Batch_Size,
    // Seq_len_q). For mem-eff attention this will cause the expand call to
    // error. For now we turn off that path to avoid dealing with all the
    // annoying mask type shape grossness
    if (!mask.has_value() && no_seq_len_1_nested &&
        (backend == sdp::SDPBackend::flash_attention ||
         backend == sdp::SDPBackend::efficient_attention)) {
      auto x = at::linear(query, qkv_weight, qkv_bias);
      auto chunks = x.chunk(3, -1);
      auto x_size_0 = x.size(0);

      chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      auto y = at::scaled_dot_product_attention(
          chunks[0], chunks[1], chunks[2], mask, 0.0, false, std::nullopt);

      auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
      return std::make_tuple(
          at::linear(past_sdp, proj_weight, proj_bias), Tensor());
    }
    // Returned math or error; let's not use it
  }

  // shape: [B, T, 3 x D]
  auto qkv = native::qkv_projection(query, key, value, embed_dim, qkv_weight);

  if (!qkv.is_nested() && qkv.numel() == 0) {
    if (query.is_nested()) {
      return std::make_tuple(Tensor(), Tensor());
    }
    return std::make_tuple(at::empty_like(query), Tensor());
  }

#ifndef NDEBUG
  if (!query.is_nested() || !qkv.is_nested()) {
    if (query.is_nested()) {
      T = qkv.size(1);
    }
    native::debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
  }
#endif

#ifdef DEBUG_PRINT_EACH_STEP
  if (!qkv.is_nested()) {
    std::cerr << "qkv: " << qkv << std::endl;
  }
#endif
  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = transform_bias_rescale_qkv_xpu(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  native::debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  native::debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  native::debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  // shape: [B, num_head, T, T]
  auto qkt = native::bmm_nt(q, k);
  // q & k are dead but cannot be freed because they were packed with v
#ifndef NDEBUG
  native::debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, T]
  // TODO: long-term, have a kernel that works with
  // NestedTensor directly if there is no mask passed
  qkt = native::masked_softmax(qkt, mask, query, mask_type);
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, dim_per_head]
  // reuse storage for q; we're done with it
  auto attn_ctx = native::bmm_nn(q, qkt, v);
  // qkv is not dead; we just reused storage for q!
  if (!need_weights) {
    qkt = Tensor();
  }
#ifndef NDEBUG
  native::debug_assert_shape(
      __LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = native::transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  native::debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  if (need_weights && average_attn_weights) {
    // weights are not needed for the full transformer, so don't worry too
    // much about performance -- we implement this just so that use cases
    // that don't disable need_weights still get some speedup.
    qkt = qkt.sum(1);
    qkt /= num_head;
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}

} // namespace native
} // namespace at
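
For reference, the following is a dense-tensor-only sketch (not part of the commit, and ignoring the nested-tensor and round-T-up-to-a-multiple-of-8 handling) of the math that transform_bias_rescale_qkv_xpu hands off to xpu::_transform_bias_rescale_qkv_kernel. The function and variable names below are illustrative only.

# Reference-only sketch of the dense (non-nested) semantics of
# _transform_bias_rescale_qkv: split the packed [B, T, 3*D] projection into
# q, k, v of shape [B, num_head, T, dim_per_head], add the bias, and rescale
# q by 1/sqrt(dim_per_head). Not the kernel API.
import math
import torch

def transform_bias_rescale_qkv_reference(qkv: torch.Tensor,
                                          qkv_bias: torch.Tensor,
                                          num_head: int):
    B, T, _3D = qkv.shape
    D = _3D // 3
    dim_per_head = D // num_head

    # [B, T, 3*D] + [3*D] -> [B, T, 3, num_head, dim_per_head]
    x = (qkv + qkv_bias).view(B, T, 3, num_head, dim_per_head)
    # -> [3, B, num_head, T, dim_per_head]
    x = x.permute(2, 0, 3, 1, 4)

    q, k, v = x[0], x[1], x[2]
    q = q / math.sqrt(dim_per_head)  # q = (q + q_bias) / sqrt(dim_per_head)
    return q, k, v

qkv = torch.randn(2, 5, 3 * 64)
qkv_bias = torch.randn(3 * 64)
q, k, v = transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_head=8)
print(q.shape)  # torch.Size([2, 8, 5, 8])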
