|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
1 | 7 | from typing import List, Optional |
2 | 8 |
|
3 | 9 | import torch |
|
239 | 245 | long col = 0; |
240 | 246 | for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { |
241 | 247 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col); |
242 | | - auto tmp1 = tmp0 * vec_sum_scale; |
243 | | - auto tmp2 = tmp1.round(); |
244 | | - auto tmp3 = tmp2 + vec_beta1; |
| 248 | + auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1); |
| 249 | + auto tmp3 = tmp1.round(); |
245 | 250 | auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); |
246 | | - store(tmp_out + col, tmp4); |
247 | 251 | auto tmp6 = at::vec::convert<int32_t>(tmp4); |
| 252 | + auto tmp7 = at::vec::convert<scalar_t>(tmp6); |
| 253 | + tmp7.store(tmp_out + col, vec_size); |
248 | 254 | vec_tmp_sum += tmp6; |
249 | 255 | } |
250 | 256 | if (col < kvBlockSize) { |
251 | 257 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col); |
252 | | - auto tmp1 = tmp0 * vec_sum_scale; |
253 | | - auto tmp2 = tmp1.round(); |
254 | | - auto tmp3 = tmp2 + vec_beta1; |
| 258 | + auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1); |
| 259 | + auto tmp3 = tmp1.round(); |
255 | 260 | auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); |
256 | | - store(tmp_out + col, tmp4, kvBlockSize - col); |
257 | 261 | auto tmp6 = at::vec::convert<int32_t>(tmp4); |
| 262 | + auto tmp7 = at::vec::convert<scalar_t>(tmp6); |
| 263 | + tmp7.store(tmp_out + col, kvBlockSize - col); |
258 | 264 | vec_tmp_sum = at::vec::Vectorized<int32_t>::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); |
259 | 265 | } |
260 | 266 | sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2; |
|
341 | 347 | long col = 0; |
342 | 348 | for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { |
343 | 349 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col); |
344 | | - auto tmp1 = tmp0 * vec_sum_scale; |
345 | | - auto tmp2 = tmp1.round(); |
346 | | - auto tmp3 = tmp2 + vec_beta1; |
| 350 | + auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1); |
| 351 | + auto tmp3 = tmp1.round(); |
347 | 352 | auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); |
348 | 353 | store(tmp_out + col, tmp4); |
349 | 354 | } |
350 | 355 | if (col < kvBlockSize) { |
351 | 356 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col); |
352 | | - auto tmp1 = tmp0 * vec_sum_scale; |
353 | | - auto tmp2 = tmp1.round(); |
354 | | - auto tmp3 = tmp2 + vec_beta1; |
| 357 | + auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1); |
| 358 | + auto tmp3 = tmp1.round(); |
355 | 359 | auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); |
356 | 360 | store(tmp_out + col, tmp4, kvBlockSize - col); |
357 | 361 | } |
|
406 | 410 | auto tmp2 = tmp1 - vec_sum_a; |
407 | 411 | auto tmp3 = tmp2 + vec_beta1; |
408 | 412 | auto tmp4 = at::vec::convert<float>(tmp3); |
409 | | - auto tmp5 = tmp4 * vec_alpha; |
410 | | - auto tmp6 = tmp5.round(); |
411 | | - auto tmp7 = tmp6 + vec_beta2; |
| 413 | + auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2); |
| 414 | + auto tmp7 = tmp5.round(); |
412 | 415 | auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); |
413 | 416 | store(tmp_out + col, tmp8); |
414 | 417 | } |
|
419 | 422 | auto tmp2 = tmp1 - vec_sum_a; |
420 | 423 | auto tmp3 = tmp2 + vec_beta1; |
421 | 424 | auto tmp4 = at::vec::convert<float>(tmp3); |
422 | | - auto tmp5 = tmp4 * vec_alpha; |
423 | | - auto tmp6 = tmp5.round(); |
424 | | - auto tmp7 = tmp6 + vec_beta2; |
| 425 | + auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2); |
| 426 | + auto tmp7 = tmp5.round(); |
425 | 427 | auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); |
426 | 428 | store(tmp_out + col, tmp8, N - col); |
427 | 429 | } |
|
463 | 465 | auto tmp3 = tmp1 - vec_sum_a; |
464 | 466 | // auto tmp3 = tmp2 + vec_beta1; |
465 | 467 | auto tmp4 = at::vec::convert<float>(tmp3); |
466 | | - auto tmp5 = tmp4 * vec_alpha; |
467 | | - auto tmp6 = tmp5.round(); |
468 | | - auto tmp7 = tmp6 + vec_beta2; |
| 468 | + auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2); |
| 469 | + auto tmp7 = tmp5.round(); |
469 | 470 | auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); |
470 | 471 | store(tmp_out + col, tmp8); |
471 | 472 | } |
472 | 473 | if (col < N) { |
473 | 474 | auto tmp1 = at::vec::Vectorized<int32_t>::loadu(tmp_in + col, N - col); |
474 | 475 | auto tmp3 = tmp1 - vec_sum_a; |
475 | 476 | auto tmp4 = at::vec::convert<float>(tmp3); |
476 | | - auto tmp5 = tmp4 * vec_alpha; |
477 | | - auto tmp6 = tmp5.round(); |
478 | | - auto tmp7 = tmp6 + vec_beta2; |
| 477 | + auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2); |
| 478 | + auto tmp7 = tmp5.round(); |
479 | 479 | auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); |
480 | 480 | store(tmp_out + col, tmp8, N - col); |
481 | 481 | } |
|
1384 | 1384 | q_sum_ptr, static_cast<int32_t>(0), qSplitSize); |
1385 | 1385 | {%- endif %} |
1386 | 1386 | const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; |
1387 | | - |
| 1387 | +
|
1388 | 1388 | for (int64_t l = 0; l < rkvSlice; l++) { |
1389 | 1389 | int64_t n = l * kvSplitSize; |
1390 | 1390 | int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); |
|
0 commit comments