Skip to content

Commit 3577306

Browse files
authored
[CPU] Improve INT8 SDPA template (#3230)
* [CPU] Improve INT8 SDPA template
* Update tail
1 parent d84f5b8 commit 3577306

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py

Lines changed: 27 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from typing import List, Optional
28

39
import torch
@@ -239,22 +245,22 @@
239245
long col = 0;
240246
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
241247
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
242-
auto tmp1 = tmp0 * vec_sum_scale;
243-
auto tmp2 = tmp1.round();
244-
auto tmp3 = tmp2 + vec_beta1;
248+
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
249+
auto tmp3 = tmp1.round();
245250
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
246-
store(tmp_out + col, tmp4);
247251
auto tmp6 = at::vec::convert<int32_t>(tmp4);
252+
auto tmp7 = at::vec::convert<scalar_t>(tmp6);
253+
tmp7.store(tmp_out + col, vec_size);
248254
vec_tmp_sum += tmp6;
249255
}
250256
if (col < kvBlockSize) {
251257
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
252-
auto tmp1 = tmp0 * vec_sum_scale;
253-
auto tmp2 = tmp1.round();
254-
auto tmp3 = tmp2 + vec_beta1;
258+
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
259+
auto tmp3 = tmp1.round();
255260
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
256-
store(tmp_out + col, tmp4, kvBlockSize - col);
257261
auto tmp6 = at::vec::convert<int32_t>(tmp4);
262+
auto tmp7 = at::vec::convert<scalar_t>(tmp6);
263+
tmp7.store(tmp_out + col, kvBlockSize - col);
258264
vec_tmp_sum = at::vec::Vectorized<int32_t>::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col);
259265
}
260266
sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2;
@@ -341,17 +347,15 @@
341347
long col = 0;
342348
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
343349
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
344-
auto tmp1 = tmp0 * vec_sum_scale;
345-
auto tmp2 = tmp1.round();
346-
auto tmp3 = tmp2 + vec_beta1;
350+
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
351+
auto tmp3 = tmp1.round();
347352
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
348353
store(tmp_out + col, tmp4);
349354
}
350355
if (col < kvBlockSize) {
351356
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
352-
auto tmp1 = tmp0 * vec_sum_scale;
353-
auto tmp2 = tmp1.round();
354-
auto tmp3 = tmp2 + vec_beta1;
357+
auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
358+
auto tmp3 = tmp1.round();
355359
auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
356360
store(tmp_out + col, tmp4, kvBlockSize - col);
357361
}
@@ -406,9 +410,8 @@
406410
auto tmp2 = tmp1 - vec_sum_a;
407411
auto tmp3 = tmp2 + vec_beta1;
408412
auto tmp4 = at::vec::convert<float>(tmp3);
409-
auto tmp5 = tmp4 * vec_alpha;
410-
auto tmp6 = tmp5.round();
411-
auto tmp7 = tmp6 + vec_beta2;
413+
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
414+
auto tmp7 = tmp5.round();
412415
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
413416
store(tmp_out + col, tmp8);
414417
}
@@ -419,9 +422,8 @@
419422
auto tmp2 = tmp1 - vec_sum_a;
420423
auto tmp3 = tmp2 + vec_beta1;
421424
auto tmp4 = at::vec::convert<float>(tmp3);
422-
auto tmp5 = tmp4 * vec_alpha;
423-
auto tmp6 = tmp5.round();
424-
auto tmp7 = tmp6 + vec_beta2;
425+
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
426+
auto tmp7 = tmp5.round();
425427
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
426428
store(tmp_out + col, tmp8, N - col);
427429
}
@@ -463,19 +465,17 @@
463465
auto tmp3 = tmp1 - vec_sum_a;
464466
// auto tmp3 = tmp2 + vec_beta1;
465467
auto tmp4 = at::vec::convert<float>(tmp3);
466-
auto tmp5 = tmp4 * vec_alpha;
467-
auto tmp6 = tmp5.round();
468-
auto tmp7 = tmp6 + vec_beta2;
468+
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
469+
auto tmp7 = tmp5.round();
469470
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
470471
store(tmp_out + col, tmp8);
471472
}
472473
if (col < N) {
473474
auto tmp1 = at::vec::Vectorized<int32_t>::loadu(tmp_in + col, N - col);
474475
auto tmp3 = tmp1 - vec_sum_a;
475476
auto tmp4 = at::vec::convert<float>(tmp3);
476-
auto tmp5 = tmp4 * vec_alpha;
477-
auto tmp6 = tmp5.round();
478-
auto tmp7 = tmp6 + vec_beta2;
477+
auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
478+
auto tmp7 = tmp5.round();
479479
auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
480480
store(tmp_out + col, tmp8, N - col);
481481
}
@@ -1384,7 +1384,7 @@
13841384
q_sum_ptr, static_cast<int32_t>(0), qSplitSize);
13851385
{%- endif %}
13861386
const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1;
1387-
1387+
13881388
for (int64_t l = 0; l < rkvSlice; l++) {
13891389
int64_t n = l * kvSplitSize;
13901390
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);

0 commit comments

Comments (0)