| 1 | +#include <ATen/NestedTensorImpl.h> |
| 2 | +#include <ATen/core/Tensor.h> |
| 3 | +#include <ATen/native/nested/NestedTensorUtils.h> |
| 4 | +#include <ATen/native/transformers/attention.h> |
| 5 | +#include <ATen/native/transformers/sdp_utils_cpp.h> |
| 6 | + |
| 7 | +#ifndef AT_PER_OPERATOR_HEADERS |
| 8 | +#include <ATen/Functions.h> |
| 9 | +#include <ATen/NativeFunctions.h> |
| 10 | +#else |
| 11 | +#include <ATen/ops/empty_like.h> |
| 12 | +#include <ATen/ops/linear.h> |
| 13 | +#include <ATen/ops/scaled_dot_product_attention.h> |
| 14 | +#include <ATen/ops/split_native.h> |
| 15 | +#endif |
| 16 | + |
| 17 | +#include <ATen/native/transformers/SDPUtils.h> |
| 18 | +#include <ATen/native/transformers/sycl/AttentionKernels.h> |
| 19 | + |
| 20 | +#include <comm/SYCLContext.h> |
| 21 | + |
| 22 | +namespace at { |
| 23 | +namespace native { |
| 24 | + |
| 25 | +// Compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias. |
| 26 | +// Note: currently only contiguous indexing is supported, since nested tensors |
| 27 | +// are always contiguous here. |
| 28 | +std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_xpu( |
| 29 | + const Tensor& qkv, |
| 30 | + const Tensor& qkv_bias, |
| 31 | + const int64_t num_head) { |
| 32 | +  // For a nested tensor, B is the outermost size, but T is irregular, so take |
| 33 | +  // the maximum size along dim 1. |
| 34 | + auto B = qkv.is_nested() |
| 35 | + ? native::get_nested_tensor_impl(qkv)->get_nested_sizes().size(0) |
| 36 | + : qkv.size(0); |
| 37 | + |
| 38 | + auto T = qkv.is_nested() ? native::NestedTensor_get_max_size( |
| 39 | + *native::get_nested_tensor_impl(qkv))[0] |
| 40 | + : qkv.size(1); |
| 41 | + if (qkv.is_nested()) { |
| 42 | + // Don't mess with non-nested case for now since it's not set up to fiddle |
| 43 | + // with mask size. |
| 44 | + |
| 45 | + // Round T up to next multiple of 8 so as to be able to utilize Tensor |
| 46 | + // cores. Otherwise, sometimes with padding, *no* row will have the maximum |
| 47 | + // sequence length and so we'll have a non-divisible-by-8 dimension even if |
| 48 | + // the model author chose a multiple of 8. |
| 49 | + T = T + (8 - (T % 8)) % 8; |
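|  | +    // e.g. T = 13 is rounded up to 16, while T = 16 is left unchanged. |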
| 50 | + } |
| 51 | + auto _3D = qkv_bias.size(0); |
| 52 | + auto D = _3D / 3; |
| 53 | +  TORCH_CHECK(D % num_head == 0, "`embed_dim` must be divisible by `num_head`"); |
| 54 | + const auto dim_per_head = D / num_head; |
| 55 | + |
| 56 | +  // q_k_v: [B, T, 3*D] -> [3, B, num_head, T, dim_per_head] |
| 57 | + auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options()); |
| 58 | + |
| 59 | + xpu::_transform_bias_rescale_qkv_kernel( |
| 60 | + qkv, qkv_bias, num_head, q_k_v, B, T, D, dim_per_head); |
| 61 | + |
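|  | +  // View as [3*B, num_head, T, dim_per_head] and split dim 0 into three |
|  | +  // [B, num_head, T, dim_per_head] tensors: q, k, v. |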
| 62 | + auto q_k_v_s = |
| 63 | + at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0); |
| 64 | + return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]); |
| 65 | +} |
| 66 | + |
| 67 | +static bool check_for_seq_len_1_nested_tensor( |
| 68 | + sdp::sdp_params params, |
| 69 | + bool debug) { |
| 70 | +  // When this function is called, the nested tensor is guaranteed to be dim == 4. |
| 71 | + if (!params.query.is_nested()) { |
| 72 | + return true; |
| 73 | + } |
| 74 | + |
| 75 | + const auto nt_q_tensor_impl = |
| 76 | + at::native::get_nested_tensor_impl(params.query); |
| 77 | + const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes(); |
| 78 | + auto* sizes_ptr = sizes.data_ptr<int64_t>(); |
| 79 | + const int64_t n_tensors = params.query.size(0); |
| 80 | + const int64_t size_tensor_stride = sizes.stride(0); |
| 81 | + |
| 82 | + // This is being called inside sdp with shape [batch, heads, {seq_len}, dim] |
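|  | +  // Each row of `sizes` is [num_heads, seq_len, head_dim]; index 1 is seq_len. |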
| 83 | + for (const auto i : c10::irange(n_tensors)) { |
| 84 | + if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) { |
| 85 | + if (debug) { |
| 86 | + TORCH_WARN( |
| 87 | + "Packed projection for fused kernels does not support sequence_length <= 1"); |
| 88 | + } |
| 89 | + return false; |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + return true; |
| 94 | +} |
| 95 | + |
| 96 | +int64_t _fused_sdp_choice_xpu( |
| 97 | + const Tensor& query, |
| 98 | + const Tensor& key, |
| 99 | + const Tensor& value, |
| 100 | + const std::optional<Tensor>& attn_mask_, |
| 101 | + double dropout_p, |
| 102 | + bool is_causal, |
| 103 | + std::optional<double> scale, |
| 104 | + bool enable_gqa) { |
| 105 | +  // The efficient_attention backend is implemented with xetla; the |
| 106 | +  // flash_attention backend is not supported yet and will be implemented in |
| 107 | +  // the future. So only two backends are offered here: efficient_attention and math. |
| 108 | + sdp::sdp_params kernel_params{ |
| 109 | + query, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; |
| 110 | +  // can_use_mem_efficient_attention only prints the reasons a fused kernel is |
| 111 | +  // rejected when print_debug is true; keep it false for the normal path. |
| 112 | + bool print_debug = false; |
| 113 | + sdp::SDPBackend backend = |
| 114 | + sdp::can_use_mem_efficient_attention(kernel_params, print_debug) |
| 115 | + ? sdp::SDPBackend::efficient_attention |
| 116 | + : sdp::SDPBackend::math; |
| 117 | + if (backend == sdp::SDPBackend::error) { |
| 118 | + TORCH_CHECK( |
| 119 | + false, |
| 120 | + "No viable backend for scaled_dot_product_attention was found. ", |
| 121 | + "This is likely due to turning off both the math kernel and the fused kernels."); |
| 122 | + } |
| 123 | + return static_cast<int64_t>(backend); |
| 124 | +} |
| 125 | + |
| 126 | +std::tuple<Tensor, Tensor> native_multi_head_attention_xpu( |
| 127 | + const Tensor& query, |
| 128 | + const Tensor& key, |
| 129 | + const Tensor& value, |
| 130 | + const int64_t embed_dim, |
| 131 | + const int64_t num_head, |
| 132 | + const Tensor& qkv_weight, |
| 133 | + const Tensor& qkv_bias, |
| 134 | + const Tensor& proj_weight, |
| 135 | + const Tensor& proj_bias, |
| 136 | + const std::optional<Tensor>& mask, |
| 137 | + bool need_weights, |
| 138 | + bool average_attn_weights, |
| 139 | + const std::optional<int64_t> mask_type) { |
| 140 | + // query shape: [B, T, D] |
| 141 | + // qkv_weight shape: [3 * D, D] |
| 142 | + |
| 143 | + TORCH_CHECK( |
| 144 | + !mask || !query.is_nested(), |
| 145 | + "NestedTensor with mask is not supported yet"); |
| 146 | + const auto D = embed_dim; |
| 147 | + TORCH_CHECK( |
| 148 | + query.dim() == 3, "expected 3-D `query`, got ", query.dim(), "-D tensor"); |
| 149 | + TORCH_CHECK( |
| 150 | + query.is_nested() || query.sizes()[2] == embed_dim, |
| 151 | + "passed-in embed_dim ", |
| 152 | + embed_dim, |
| 153 | + " didn't match last dim of query ", |
| 154 | + query.sizes()[2]); |
| 155 | + TORCH_CHECK( |
| 156 | + key.dim() == 3, "expected 3-D `key`, got ", key.dim(), "-D tensor"); |
| 157 | + TORCH_CHECK( |
| 158 | + value.dim() == 3, "expected 3-D `value`, got ", value.dim(), "-D tensor"); |
| 159 | + TORCH_CHECK( |
| 160 | + query.is_nested() || key.is_nested() || value.is_nested() || |
| 161 | + (query.sizes() == key.sizes() && key.sizes() == value.sizes()), |
| 162 | + "expected `query`/`key`/`value` shapes to match"); |
| 163 | + TORCH_CHECK( |
| 164 | + qkv_weight.dim() == 2, |
| 165 | + "expected 2-D `qkv_weight`, got ", |
| 166 | + qkv_weight.dim(), |
| 167 | + "-D tensor"); |
| 168 | + TORCH_CHECK( |
| 169 | + D * 3 == qkv_weight.sizes()[0], |
| 170 | + "expected `qkv_weight` first dim to be 3x embed_dim"); |
| 171 | + TORCH_CHECK( |
| 172 | + D == qkv_weight.sizes()[1], |
| 173 | + "expected `qkv_weight` second dim to be embed_Dim"); |
| 174 | + TORCH_CHECK( |
| 175 | + qkv_bias.dim() == 1, |
| 176 | + "expected 1-D `qkv_bias`, got ", |
| 177 | + qkv_bias.dim(), |
| 178 | + "-D tensor"); |
| 179 | + TORCH_CHECK( |
| 180 | + qkv_bias.sizes()[0] == 3 * D, |
| 181 | + "expected `qkv_bias` first dim and first dim of query to be equal"); |
| 182 | + TORCH_CHECK( |
| 183 | + D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`"); |
| 184 | + |
| 185 | +#ifndef NDEBUG |
| 186 | + const auto B = query.is_nested() |
| 187 | + ? native::get_nested_tensor_impl(query)->get_nested_sizes().size(0) |
| 188 | + : query.sizes()[0]; |
| 189 | + auto T = query.is_nested() ? 0 : query.sizes()[1]; |
| 190 | + |
| 191 | +#endif |
| 192 | + const auto dim_per_head = D / num_head; |
| 193 | + if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && |
| 194 | + !need_weights) { |
| 195 | +    // We have not done the linear projection yet, but the input for SDP is |
| 196 | +    // expected to be 4-dimensional. We "cheaply" create view tensors that |
| 197 | +    // are then used for checking the hot-path conditions with |
| 198 | +    // select_sd_backend. |
| 199 | + auto q = |
| 200 | + query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2); |
| 201 | + auto k = |
| 202 | + key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2); |
| 203 | + auto v = |
| 204 | + value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2); |
| 205 | + |
| 206 | + sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false}; |
| 207 | + auto backend = static_cast<sdp::SDPBackend>( |
| 208 | + _fused_sdp_choice_xpu(q, k, v, mask, 0.0, false, {}, false)); |
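|  | +    // On XPU this resolves to either efficient_attention or math; |
|  | +    // flash_attention is not supported yet (see _fused_sdp_choice_xpu above). |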
| 209 | + |
| 210 | +    // For nested tensors with seq_len == 1, the strides produced by the packed |
| 211 | +    // projection would trigger a contiguous call in the kernel, so we prevent this. |
| 212 | + bool no_seq_len_1_nested = query.is_nested() |
| 213 | + ? check_for_seq_len_1_nested_tensor(kernel_params, false) |
| 214 | + : true; |
| 215 | +    // The transformer_encoder API passes a mask of shape (Batch_Size, |
| 216 | +    // Seq_len_q). For mem-efficient attention this would cause the expand call |
| 217 | +    // to error out, so for now that path is turned off rather than dealing |
| 218 | +    // with all the annoying mask-type/shape handling. |
| 219 | + if (!mask.has_value() && no_seq_len_1_nested && |
| 220 | + (backend == sdp::SDPBackend::flash_attention || |
| 221 | + backend == sdp::SDPBackend::efficient_attention)) { |
| 222 | + auto x = at::linear(query, qkv_weight, qkv_bias); |
| 223 | + auto chunks = x.chunk(3, -1); |
| 224 | + auto x_size_0 = x.size(0); |
| 225 | + |
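|  | +    // Reshape each [B, T, D] chunk to [B, num_head, T, dim_per_head] as |
|  | +    // expected by scaled_dot_product_attention. |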
| 226 | + chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head})) |
| 227 | + .transpose(1, 2); |
| 228 | + chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head})) |
| 229 | + .transpose(1, 2); |
| 230 | + chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head})) |
| 231 | + .transpose(1, 2); |
| 232 | + auto y = at::scaled_dot_product_attention( |
| 233 | + chunks[0], chunks[1], chunks[2], mask, 0.0, false, std::nullopt); |
| 234 | + |
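|  | +    // Merge heads back: [B, num_head, T, dim_per_head] -> [B, T, embed_dim], |
|  | +    // then apply the output projection. |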
| 235 | + auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim}); |
| 236 | + return std::make_tuple( |
| 237 | + at::linear(past_sdp, proj_weight, proj_bias), Tensor()); |
| 238 | + } |
| 239 | +    // The chosen backend was math or error, so don't take the fused path. |
| 240 | + } |
| 241 | + |
| 242 | + // shape: [B, T, 3 x D] |
| 243 | + auto qkv = native::qkv_projection(query, key, value, embed_dim, qkv_weight); |
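|  | +  // Note: qkv_bias is not applied here; it is added later in |
|  | +  // transform_bias_rescale_qkv_xpu. |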
| 244 | + |
| 245 | + if (!qkv.is_nested() && qkv.numel() == 0) { |
| 246 | + if (query.is_nested()) { |
| 247 | + return std::make_tuple(Tensor(), Tensor()); |
| 248 | + } |
| 249 | + return std::make_tuple(at::empty_like(query), Tensor()); |
| 250 | + } |
| 251 | + |
| 252 | +#ifndef NDEBUG |
| 253 | + if (!query.is_nested() || !qkv.is_nested()) { |
| 254 | + if (query.is_nested()) { |
| 255 | + T = qkv.size(1); |
| 256 | + } |
| 257 | + native::debug_assert_shape(__LINE__, qkv, {B, T, 3 * D}); |
| 258 | + } |
| 259 | +#endif |
| 260 | + |
| 261 | +#ifdef DEBUG_PRINT_EACH_STEP |
| 262 | + if (!qkv.is_nested()) { |
| 263 | + std::cerr << "qkv: " << qkv << std::endl; |
| 264 | + } |
| 265 | +#endif |
| 266 | + // shape: 3 x [B, num_head, T, dim_per_head] |
| 267 | + auto q_k_v = transform_bias_rescale_qkv_xpu(qkv, qkv_bias, num_head); |
| 268 | + qkv = Tensor(); // Not used any more, allow free |
| 269 | + auto& q = std::get<0>(q_k_v); |
| 270 | + const auto& k = std::get<1>(q_k_v); |
| 271 | + const auto& v = std::get<2>(q_k_v); |
| 272 | +#ifndef NDEBUG |
| 273 | + native::debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head}); |
| 274 | + native::debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head}); |
| 275 | + native::debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head}); |
| 276 | +#endif |
| 277 | +#ifdef DEBUG_PRINT_EACH_STEP |
| 278 | + std::cerr << "q: " << q << std::endl; |
| 279 | + std::cerr << "k: " << k << std::endl; |
| 280 | + std::cerr << "v: " << v << std::endl; |
| 281 | +#endif |
| 282 | + |
| 283 | + // shape: [B, num_head, T, T] |
| 284 | + auto qkt = native::bmm_nt(q, k); |
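|  | +  // bmm_nt computes q @ k^T; the 1/sqrt(dim_per_head) scaling was already |
|  | +  // folded into q by transform_bias_rescale_qkv_xpu. |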
| 285 | + // q & k are dead but cannot be freed because they were packed with v |
| 286 | +#ifndef NDEBUG |
| 287 | + native::debug_assert_shape(__LINE__, qkt, {B, num_head, T, T}); |
| 288 | +#endif |
| 289 | +#ifdef DEBUG_PRINT_EACH_STEP |
| 290 | + std::cerr << "qkt: " << qkt << std::endl; |
| 291 | +#endif |
| 292 | + |
| 293 | + // shape: [B, num_head, T, T] |
| 294 | + // TODO: long-term, have a kernel that works with |
| 295 | + // NestedTensor directly if there is no mask passed |
| 296 | + qkt = native::masked_softmax(qkt, mask, query, mask_type); |
| 297 | +#ifdef DEBUG_PRINT_EACH_STEP |
| 298 | + std::cerr << "qkt after softmax: " << qkt << std::endl; |
| 299 | +#endif |
| 300 | + |
| 301 | + // shape: [B, num_head, T, dim_per_head] |
| 302 | + // reuse storage for q; we're done with it |
| 303 | + auto attn_ctx = native::bmm_nn(q, qkt, v); |
| 304 | + // qkv is not dead; we just reused storage for q! |
| 305 | + if (!need_weights) { |
| 306 | + qkt = Tensor(); |
| 307 | + } |
| 308 | +#ifndef NDEBUG |
| 309 | + native::debug_assert_shape( |
| 310 | + __LINE__, attn_ctx, {B, num_head, T, dim_per_head}); |
| 311 | +#endif |
| 312 | +#ifdef DEBUG_PRINT_EACH_STEP |
| 313 | + std::cerr << "attn_ctx: " << attn_ctx << std::endl; |
| 314 | +#endif |
| 315 | + |
| 316 | + // shape: [B, T, D] |
| 317 | + // Fuse transform_0213 inside |
| 318 | + auto proj = native::transform0213_gemm_nt_bias( |
| 319 | + attn_ctx, proj_weight, proj_bias, query); |
| 320 | +#ifndef NDEBUG |
| 321 | + native::debug_assert_shape(__LINE__, proj, {B, T, D}); |
| 322 | +#endif |
| 323 | + if (need_weights && average_attn_weights) { |
| 324 | +    // Weights are not needed for the full transformer, so don't worry too much |
| 325 | +    // about performance -- this is implemented just so that use cases which do |
| 326 | +    // not disable need_weights still get some speedup. |
| 327 | + qkt = qkt.sum(1); |
| 328 | + qkt /= num_head; |
| 329 | + } |
| 330 | + return std::make_tuple(std::move(proj), std::move(qkt)); |
| 331 | +} |
| 332 | + |
| 333 | +} // namespace native |
| 334 | +} // namespace at |