From 9989aaf6b07d0ec45ee09c94194725cf5b4cdefe Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 5 Aug 2025 06:59:45 +0000 Subject: [PATCH 1/4] [cpu][fp8] support fp8 sdpa for cpu --- .../inductor/test_int8_sdpa_fusion.py | 12 +- test/test_ops.py | 204 +++-- .../cpu/{int8_sdpa.cpp => quantized_sdpa.cpp} | 740 +++++++++++++++++- .../prototype/inductor/fx_passes/README.md | 2 +- .../prototype/inductor/fx_passes/__init__.py | 4 +- .../inductor/fx_passes/int8_sdpa_fusion.py | 396 ---------- .../inductor/fx_passes/qsdpa_fusion.py | 472 +++++++++++ ...nt8_sdpa_lowering.py => qsdpa_lowering.py} | 67 +- 8 files changed, 1336 insertions(+), 561 deletions(-) rename torchao/csrc/cpu/{int8_sdpa.cpp => quantized_sdpa.cpp} (71%) delete mode 100644 torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py create mode 100644 torchao/prototype/inductor/fx_passes/qsdpa_fusion.py rename torchao/prototype/inductor/{int8_sdpa_lowering.py => qsdpa_lowering.py} (67%) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index ec4f928df2..ad0d093480 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -11,8 +11,8 @@ from torch.testing._internal.inductor_utils import HAS_CPU import torchao -from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import ( - _int8_sdpa_init, +from torchao.prototype.inductor.fx_passes.qsdpa_fusion import ( + _qsdpa_init, custom_pass, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @@ -120,7 +120,7 @@ def _check_common( ) source_code = "\n".join(source_code) if has_fuse_pattern: - self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) + self.assertGreaterEqual(counters["inductor"]["qsdpa_fuse_attention"], 1) if contains: self.assertTrue( any( @@ -157,7 +157,7 @@ def _check_common( ) @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): - from torch.export import export_for_training + from torch.export import export import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -193,13 +193,13 @@ def _test_sdpa_int8_rewriter(self): ), config.patch(post_grad_custom_pre_pass=custom_pass), ): - _int8_sdpa_init() + _qsdpa_init() quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) quantizer.set_function_type_qconfig( torch.matmul, quantizer.get_global_quantization_config() ) - export_model = export_for_training( + export_model = export( mod, inputs, strict=True, diff --git a/test/test_ops.py b/test/test_ops.py index faec689a69..d74066db27 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -155,50 +155,101 @@ def _scaled_dot_product_int8_op_ref( out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) return out.to(torch.uint8) + def _scaled_dot_product_fp8_op_ref( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_scale=1.0, + k_scale=1.0, + v_scale=1.0, + a_scale=1.0, + o_scale=1.0, + ): + q = q.to(torch.float) * q_scale + k = k.to(torch.float) * k_scale + v = v.to(torch.float) * v_scale + scale_factor = 1 / math.sqrt(q.size(-1)) + attn = q @ k.transpose(-2, -1) + + attn = attn * scale_factor + if attn_mask is not None: + attn = attn + attn_mask.to(torch.float) + attn_max = attn.max(dim=-1, keepdim=True).values + attn = attn - attn_max + attn = torch.exp(attn) + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + attn = attn / 
attn_sum + attn = torch.clamp(attn / a_scale, min=-448, max=448) + attn = attn.to(torch.float8_e4m3fn).to(torch.float) + attn = attn * a_scale + out = attn @ v + out = torch.clamp(out / o_scale, min=-448, max=448) + return out.to(torch.float8_e4m3fn) + @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + not TORCH_VERSION_AT_LEAST_2_7, + reason="quantized sdpa requires torch 2.7 or later", ) @pytest.mark.skipif(not IS_LINUX, reason="only support on linux") @pytest.mark.skipif( "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"), reason="cpp kernels not built", ) + @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn]) @parametrize("batch_size", [56, 120]) @parametrize("n_head", [2, 16]) @parametrize("q_seq_len", [18, 89]) @parametrize("kv_seq_len", [100, 253]) @parametrize("head_dim", [32, 64]) @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16]) - def test_scaled_dot_product_int8_op( - self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype + def test_quantized_scaled_dot_product_op( + self, + input_dtype, + batch_size, + n_head, + q_seq_len, + kv_seq_len, + head_dim, + mask_dtype, ): torch.manual_seed(1234) device = "cpu" - q_scale = float(1.7907238006591797) - q_zp = int(127) - k_scale = float(1.8039721250534058) - k_zp = int(125) - v_scale = float(1.839004635810852) - v_zp = int(127) - a_scale = float(0.003919653594493866) - a_zp = int(120) - o_scale = float(1.8191684484481812) - o_zp = int(128) + if input_dtype == torch.uint8: + q_scale = float(1.7907238006591797) + k_scale = float(1.8039721250534058) + v_scale = float(1.839004635810852) + a_scale = float(0.003919653594493866) + o_scale = float(1.8191684484481812) + q_zp = int(127) + k_zp = int(125) + v_zp = int(127) + a_zp = int(120) + o_zp = int(128) + atol, rtol = 1.0, 5e-6 + else: + q_scale = float(5.96875) + k_scale = float(5.78125) + v_scale = float(0.98046875) + a_scale = float(4.84375) + o_scale = float(3.171875) + atol, rtol = 0.125, 5e-6 q_shape = [batch_size, q_seq_len, n_head, head_dim] kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] - q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 - k = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - v = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - q = q.to(torch.uint8) - k = k.to(torch.uint8) - v = v.to(torch.uint8) + q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) + k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + if input_dtype == torch.uint8: + q *= 100 + k *= 100 + v *= 100 + q = q.to(input_dtype) + k = k.to(input_dtype) + v = v.to(input_dtype) attn_mask = ( torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None @@ -211,44 +262,71 @@ def test_scaled_dot_product_int8_op( attn_mask.clone() if mask_dtype is not None else None, ) - math_ref = self._scaled_dot_product_int8_op_ref( - q2, - k2, - v2, - attn_mask=attn_mask, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, - q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - actual = torch.ops.torchao.qscaled_dot_product( - q, - k, - v, - attn_mask=attn_mask_2, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, 
- q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - - self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) + if input_dtype == torch.uint8: + math_ref = self._scaled_dot_product_int8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + actual = torch.ops.torchao.qscaled_dot_product( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + else: + math_ref = self._scaled_dot_product_fp8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + a_scale=a_scale, + o_scale=o_scale, + ) + actual = torch.ops.torchao.qscaled_dot_product( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + a_scale=a_scale, + o_scale=o_scale, + ) + self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol) instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/quantized_sdpa.cpp similarity index 71% rename from torchao/csrc/cpu/int8_sdpa.cpp rename to torchao/csrc/cpu/quantized_sdpa.cpp index a5928f6d9a..1e34b0d5e7 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/quantized_sdpa.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -82,12 +83,55 @@ inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at } template -inline typename std::enable_if_t || std::is_same_v, void> +inline typename std::enable_if_t || std::is_same_v || std::is_same_v, void> _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert(src); res.store(dst, size); } +/* +out = val * a + b +is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), + take b as a scalar pointer. +*/ +template +inline void _scale_dequant_attn_mask_fusion_kernel( + T1* a, + T2* b, + const int& size, + T1* out, + const T1& val) { + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && at::vec::is_reduced_floating_point_v) ? 2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + at::vec::VectorizedN b_n; + if constexpr(is_b_stride_zero) { + b_n = at::vec::VectorizedN((T1)b[0]); + } else { + b_n = at::vec::VectorizedN::loadu(b + i); + } + auto b_n_convert = at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + T1 tmp1; + if constexpr(is_b_stride_zero) { + tmp1 = (T1)b[0]; + } else { + tmp1 = (T1)b[i]; + } + out[i] = tmp0 * val + tmp1; + } +} + /* 1. dequant 2. 
add mask @@ -618,7 +662,7 @@ inline void _int_sum_a_contiguous_kernel( // do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] template inline void do_transpose( - scalar_t* src, + const scalar_t* src, scalar_t* dst, int64_t in_rows, int64_t in_cols, @@ -673,7 +717,7 @@ inline void pad_remain_row_col( // copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] template inline void copy_value_with_pad( - scalar_t* value_ptr, + const scalar_t* value_ptr, scalar_t* dst_ptr, int rows, int cols, @@ -725,13 +769,122 @@ inline void copy_value_with_pad( } +/* +1. out = a * scale +2. max = max(out) +*/ +template +inline void _mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + _store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + auto reduced_tmp_max = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max); + // Guard against Q*K^T being NaN + max = std::isnan(reduced_tmp_max) ? std::numeric_limits::quiet_NaN() + : std::max(tmp_max, reduced_tmp_max); +} + +/* +1. out = exp(a - val) +2. val = sum(out) +3. quant +*/ +inline void _fp8_exp_reduce_sum_quant_fusion_kernel( + float* a, + const int& size, + at::Float8_e4m3fn* out, + float& val, + const float& scale) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float min_val = -448; + float max_val = 448; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_scale = at::vec::Vectorized(scale); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + auto tmp3 = tmp2 * vec_scale; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(out + i, tmp4); + } + if (i < size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i, size - i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, size - i); + auto tmp3 = tmp2 * vec_scale; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(out + i, tmp4, size - i); + } + val = vec_tmp_sum.reduce_add(); +} + +/* +1. dequant +2. 
quant +*/ +inline void _fp8_dequant_quant_fusion_kernel( + float* a, + const int& size, + at::Float8_e4m3fn* out, + const float& scale) { + auto vec_size = at::vec::Vectorized::size(); + float min_val = -448; + float max_val = 448; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_scale = at::vec::Vectorized(scale); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + auto tmp2 = at::vec::clamp(tmp1, vec_min_val, vec_max_val); + _store(out + i, tmp2); + } + if (i < size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i, size - i); + auto tmp1 = tmp0 * vec_scale; + auto tmp2 = at::vec::clamp(tmp1, vec_min_val, vec_max_val); + _store(out + i, tmp2, size - i); + } +} + // UINT8 - one parallel loop with u8u8s32 GEMM template = 0> inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -830,9 +983,9 @@ sdpa_int8_fused_kernel_impl( int av_gemm_K = kvSplitSize + av_gemm_K_padding; // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); + const scalar_t* q_data = query.data_ptr(); + const scalar_t* k_data = key.data_ptr(); + const scalar_t* v_data = value.data_ptr(); mask_t* mask_data = attention_mask.has_value() ? attention_mask.value().data_ptr() : nullptr; @@ -931,7 +1084,7 @@ sdpa_int8_fused_kernel_impl( bool istail = kvBlockSize - b < block_64; int64_t trans_rows = istail ? kvBlockSize - b : block_64; do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + reinterpret_cast(k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN), B_blocked_xform_u8, trans_rows, headSize, @@ -1159,7 +1312,7 @@ template = 0> inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1622,10 +1775,371 @@ sdpa_int8_fused_kernel_impl( at::native::cpublas::brgemm_release(); } +// FP8 - kernel with f8f8f8 GEMM +template +inline typename std::enable_if_t, void> +fp8_sdpa_fused_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + using accum_t = float; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = calculate_scale(query, scale).expect_float(); + + // Sizes + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size"); + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + bool has_attn_mask = 
attn_mask.has_value() && attn_mask.value().numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (has_attn_mask && attn_mask.value().size(0) > 1) + ? attn_mask.value().stride(0) + : 0; + int64_t mStrideH = + (has_attn_mask && attn_mask.value().size(1) > 1) + ? attn_mask.value().stride(1) + : 0; + int64_t mStrideM = + (has_attn_mask && attn_mask.value().size(2) > 1) + ? attn_mask.value().stride(2) + : 0; + int64_t mStrideN = + (has_attn_mask && attn_mask.value().size(3) > 1) + ? attn_mask.value().stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + // Pad is needed for packing when K is not even + bool headSize_even = headSize % 4 == 0; + int64_t eheadSize = !headSize_even ? headSize + 4 - headSize % 4: headSize; + int64_t ekvSplitSize = (kvSplitSize % 4 != 0) ? kvSplitSize + 4 - kvSplitSize % 4 : kvSplitSize; + int64_t ekvTail = (kvTail % 4 != 0) ? kvTail + 4 - kvTail % 4 : kvTail; + + // Allocate per thread temp buf (accumulate type) + int64_t size_per_thread = + /* qk */ qSplitSize * kvSplitSize + + /* qk_max */ qSplitSize + + /* qk_sum */ qSplitSize + + /* dst */ qSplitSize * headSize; + + at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(at::kFloat)); + at::Tensor buf_reduced = at::empty( + {num_thread, + qSplitSize, + ekvSplitSize}, + query.options()); + + // Data ptrs + const scalar_t* q_data = query.const_data_ptr(); + const scalar_t* k_data = key.const_data_ptr(); + const scalar_t* v_data = value.const_data_ptr(); + mask_t* mask_data = has_attn_mask + ? 
attn_mask.value().data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + // accum_t* lse_data = logsumexp.data_ptr(); + accum_t* buf_data = buf.data_ptr(); + scalar_t* buf_reduced_data = buf_reduced.data_ptr(); + + // Buffer to store padding query and packing key/value + int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + at::Tensor key_t_reorder = at::empty( + {batchSize, num_head, eheadSize, kvSize}, + c10::CppTypeToScalarType::value); + at::Tensor value_t_reorder = at::empty( + {batchSize, num_head, kv_padding_size, headSize}, + c10::CppTypeToScalarType::value); + scalar_t* key_reorder_ptr = key_t_reorder.data_ptr(); + scalar_t* value_reorder_ptr = value_t_reorder.data_ptr(); + + scalar_t* query_padding_ptr = nullptr; + at::Tensor query_t_padding; + if (!headSize_even) { + query_t_padding = at::empty( + {num_thread, qSplitSize, eheadSize}, + c10::CppTypeToScalarType::value); + query_padding_ptr = query_t_padding.data_ptr(); + } + + // Reorder K, V + at::Tensor tranpose_t_reorder = at::empty( + {num_thread, kvSplitSize, headSize}, + c10::CppTypeToScalarType::value); + scalar_t* transpose_buffer_ptr = tranpose_t_reorder.data_ptr(); + at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize; + at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + + // transpose [kvBlockSize, headSize] -> [headSize, kvBlockSize] + at::native::utils::transpose( + kvBlockSize, + headSize, + /* src */ reinterpret_cast(k_data + i * kStrideB + j * kStrideH + n * kStrideN), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_ptr), + /* ld_dst */ kvBlockSize); + + // Pack [headSize, kvBlockSize] + at::vec::pack_vnni4( + /* src */ reinterpret_cast(transpose_ptr), + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), + /* ld_src */ kvBlockSize, + /* K */ headSize, + /* N */ kvBlockSize); + + // Pack [kvBlockSize, headSize] + at::vec::pack_vnni4( + /* src */ reinterpret_cast(v_data + i * vStrideB + j * vStrideH + n * vStrideN), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + n * headSize), + /* ld_src */ vStrideN, + /* K */ kvBlockSize, + /* N */ headSize); + + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; + accum_t* qk_data = buf_ptr; + accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_sum_data = qk_max_data + qSplitSize; + accum_t* dst_data = qk_sum_data + qSplitSize; + scalar_t* qk_reduced_data = buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize; + scalar_t* query_t_padding_ptr = !headSize_even + ? 
query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; + + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize max and sum + fill_stub(qk_max_data, + -std::numeric_limits::infinity(), qBlockSize); + fill_stub(qk_sum_data, + static_cast(0), qBlockSize); + int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + if (!headSize_even) { + // Pad query if headSize is not even + // [qBlockSize, headSize] -> [qBlockSize, eheadSize] + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + eheadSize, + qStrideM + ); + } + for (int64_t n = 0; n < num_keys; n += kvSplitSize) { + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = (kvBlockSize % 4 != 0) ? kvBlockSize + 4 - kvBlockSize % 4 : kvBlockSize; + // Calculate scale * q @ k.T + at::native::cpublas::brgemm( + qBlockSize, + kvBlockSize, + eheadSize, + headSize_even ? qStrideM : eheadSize, + kvBlockSize, + kvBlockSize, + false, + !headSize_even + ? query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize, + qk_data); + // Apply causal mask, fill unused with -inf + if (is_causal && num_keys - n <= kvSplitSize) { + for (const auto row : c10::irange(qBlockSize)) { + int64_t last_col = m + row - n; + accum_t* row_ptr = qk_data + row * kvBlockSize; + fill_stub(row_ptr + last_col + 1, + -std::numeric_limits::infinity(), + kvBlockSize - last_col - 1); + } + } + // Update attention weights with attention mask + // And apply scaling factor + // qk <- qk * scaling + attn_mask + if (has_attn_mask) { + for (int64_t row = 0; row < qBlockSize; ++row) { + if (mStrideN == 0) { + _scale_dequant_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale); + } else { + _scale_dequant_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + n, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale); + } + } + } + // Update coefficients with Softmax + accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; + for (int64_t row = 0; row < qBlockSize; ++row) { + if (has_attn_mask) { + // max per row + tmp_max = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, + qk_data + row * kvBlockSize, + kvBlockSize); + } else { + // apply scaling factor and max per row in fusion + _mul_reduce_max_fusion_kernel( + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale, + kvBlockSize, + qk_data + row * kvBlockSize, + tmp_max); + } + tmp_max = qk_max_data[row] > tmp_max ? 
qk_max_data[row] : tmp_max; + if (tmp_max == -std::numeric_limits::infinity()) { + // to avoid `nan = exp2f(-inf - (-inf))` + fill_stub(qk_reduced_data + row * ekvBlockSize, + static_cast(0), kvBlockSize); + } else { + tmp_sum = tmp_max; + // qk <- exp(qk - max) and sum per row + _fp8_exp_reduce_sum_quant_fusion_kernel( + qk_data + row * kvBlockSize, kvBlockSize, + qk_reduced_data + row * ekvBlockSize, + tmp_sum, + 1.0 / a_scale); + // exp_tmp <- exp(max[row] - max) + exp_tmp = std::exp(qk_max_data[row] - tmp_max); + // sum[row] <- sum + exp_tmp * sum[row] + qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; + // max[row] <- max + qk_max_data[row] = tmp_max; + // dst <- dst * exp_tmp + if (n > 0) { + at::vec::map( + [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, + dst_data + row * headSize, + dst_data + row * headSize, + headSize); + } + } + if (kvBlockSize % 4 != 0) { + // Pad: [qSplitSize, kvBlockSize] -> [qSplitSize, kvBlockSize + 4 - kvBlockSize / 4] + for (int64_t psize = kvBlockSize; psize < ekvBlockSize; ++psize) { + *(qk_reduced_data + row * ekvBlockSize + psize) = scalar_t(0); + } + } + } + // Calculate Softmax(q @ k.T) @ v + int64_t psize = n / kvSplitSize * ekvSplitSize; + at::native::cpublas::brgemm( + qBlockSize, + headSize, + ekvBlockSize, + ekvBlockSize, + headSize, + headSize, + n > 0, + qk_reduced_data, + value_reorder_ptr + + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + psize * headSize, + dst_data); + } + + // dst <- dst / sum[row] + // reorder MHA output with strides + for (int64_t row = 0; row < qBlockSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; + accum_t sum_reciprocal = 1 / qk_sum_data[row]; + _fp8_dequant_quant_fusion_kernel( + dst_data + row * headSize, + headSize, + out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, + sum_reciprocal * a_scale * v_scale / o_scale); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + at::native::cpublas::brgemm_release(); + }); +} template inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( bool use_one_parallel_loop, const at::Tensor& output, const at::Tensor& query, @@ -1646,7 +2160,7 @@ sdpa_int8_fused_kernel_impl( float o_scale, int32_t o_zp) { if (use_one_parallel_loop) { - sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1656,7 +2170,7 @@ sdpa_int8_fused_kernel_impl( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1668,7 +2182,6 @@ sdpa_int8_fused_kernel_impl( } } - #define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_SWITCH( \ TYPE, \ @@ -1684,7 +2197,7 @@ sdpa_int8_fused_kernel_impl( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::Half, mask_t, __VA_ARGS__)) -void sdpa_int8_fused_kernel( +void int8_sdpa_fused_kernel( const at::Tensor& output, const at::Tensor& query, const at::Tensor& key, @@ -1724,7 +2237,7 @@ void sdpa_int8_fused_kernel( (attn_size > 1.5 * l2_cache_size); if (!attn_mask.has_value()) { if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1734,7 +2247,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1744,7 +2257,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1757,7 +2270,7 @@ void sdpa_int8_fused_kernel( } else { AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1767,7 +2280,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1777,7 +2290,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1790,9 +2303,86 @@ void sdpa_int8_fused_kernel( }); } } + +void fp8_sdpa_fused_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + TORCH_CHECK(query.scalar_type() == c10::kFloat8_e4m3fn); + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + int64_t kv_seq_len = key.size(2); + int64_t q_split_size = 32; + if (q_seq_len >= 768) { + q_split_size = 256; + } else if (q_seq_len >= 192) { + q_split_size = 64; + } + + if (!attn_mask.has_value()) { + if (q_split_size == 256) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else if (q_split_size == 64) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { + if (q_split_size == 256) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else if (q_split_size == 64) { + fp8_sdpa_fused_kernel_impl( + 
output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } + }); + } +} #endif // CPU_CAPABILITY_AVX512 -at::Tensor sdpa_int8_math_kernel( +at::Tensor int8_sdpa_math_kernel( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, @@ -1834,6 +2424,43 @@ at::Tensor sdpa_int8_math_kernel( return output; } +at::Tensor fp8_sdpa_math_kernel( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + // dequant q/k/v + auto q = query.to(at::kFloat) * q_scale; + auto k = key.to(at::kFloat) * k_scale; + auto v = value.to(at::kFloat) * v_scale; + const auto scaling_factor = calculate_scale(q, scale); + auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; + if (attn_mask.has_value() && attn_mask.value().numel()) { + attn = attn.add(attn_mask.value().to(at::kFloat)); + } + attn = at::softmax(attn, -1); + // quant attn + attn = at::clamp_max( + at::clamp_min(attn / a_scale, -448), 448 + ); + attn = attn.to(at::kFloat8_e4m3fn).to(at::kFloat); + // dequant attn + attn = attn * a_scale; + auto output = at::matmul(attn, v); + // quant output + output = at::clamp_max( + at::clamp_min(output / o_scale, -448), 448 + ).to(at::kFloat8_e4m3fn); + return output; +} at::Tensor _qscaled_dot_product_cpu( const at::Tensor& query, @@ -1858,8 +2485,8 @@ at::Tensor _qscaled_dot_product_cpu( "_qscaled_dot_product_cpu: Only accept plain inputs"); TORCH_CHECK(!is_causal, "_qscaled_dot_product_cpu: is_causal not supported."); - TORCH_CHECK(dtype == at::ScalarType::Byte, - "_qscaled_dot_product_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(dtype == at::ScalarType::Byte || dtype == at::ScalarType::Float8_e4m3fn, + "_qscaled_dot_product_cpu: Expected data type be U8 or Float8_e4m3, but got ", dtype, " instead."); TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, "_qscaled_dot_product_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); TORCH_CHECK(dropout_p == 0.0, @@ -1873,30 +2500,63 @@ at::Tensor _qscaled_dot_product_cpu( TORCH_CHECK(!attn_mask.has_value() || (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), "_qscaled_dot_product_cpu: Attention mask dim in {2, 4}"); + if (dtype == at::ScalarType::Float8_e4m3fn) { + TORCH_CHECK(q_zp == 0 && k_zp == 0 && v_zp == 0 && a_zp == 0 && o_zp == 0, + "_qscaled_dot_product_cpu: Don't accept zero point for Float8_e4m3"); + } - #ifdef CPU_CAPABILITY_AVX512 - if (at::native::cpublas::could_pack(dtype)) { - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); - sdpa_int8_fused_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - return output.transpose(1, 2); - } else { - #endif // CPU_CAPABILITY_AVX512 - return sdpa_int8_math_kernel(query, key, value, + if (dtype == at::ScalarType::Byte) { +#ifdef CPU_CAPABILITY_AVX512 + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + std::cout << "int8_sdpa_fused_kernel" << std::endl; + int8_sdpa_fused_kernel(output, query, key, 
value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + return output.transpose(1, 2); + } else { +#endif // CPU_CAPABILITY_AVX512 + std::cout << "int8_sdpa_math_kernel" << std::endl; + return int8_sdpa_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, q_scale, q_zp, k_scale, k_zp, v_scale, v_zp, a_scale, a_zp, o_scale, o_zp).transpose(1, 2).contiguous().transpose(1, 2); - #ifdef CPU_CAPABILITY_AVX512 - } - #endif // CPU_CAPABILITY_AVX512 +#ifdef CPU_CAPABILITY_AVX512 + } +#endif // CPU_CAPABILITY_AVX512 + } else if (dtype == at::ScalarType::Float8_e4m3fn) { +#if defined(CPUBLAS_BRGEMM_F8F8F32) && defined(CPU_CAPABILITY_AVX512) +// CPUBLAS_BRGEMM_F8F8F32 is defined if FP8 BRGEMM is supported in PyTorch CPUBlas. + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + std::cout << "fp8_sdpa_fused_kernel" << std::endl; + fp8_sdpa_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + return output.transpose(1, 2); + } else { +#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32 + std::cout << "fp8_sdpa_math_kernel" << std::endl; + return fp8_sdpa_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale).transpose(1, 2).contiguous().transpose(1, 2); +#if defined(CPUBLAS_BRGEMM_F8F8F32) && defined(CPU_CAPABILITY_AVX512) + } +#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32 + } else { + TORCH_CHECK(false, "_qscaled_dot_product_cpu: Unsupported data type ", dtype); + } } diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md index 7007aba993..fe4939a314 100644 --- a/torchao/prototype/inductor/fx_passes/README.md +++ b/torchao/prototype/inductor/fx_passes/README.md @@ -11,7 +11,7 @@ In TorchAO, you can replace the following customized graph passes of Inductor: ## Directory Structure -- `int8_sdpa_fusion`: Pattern match for int8 sdpa fusion. +- `qsdpa_fusion`: Pattern match for qsdpa fusion. 
## Getting Started diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py index 7ba311bf41..eff7ff1dc2 100644 --- a/torchao/prototype/inductor/fx_passes/__init__.py +++ b/torchao/prototype/inductor/fx_passes/__init__.py @@ -1,7 +1,7 @@ from .da8w4_concat_linear_fusion_cpu import register_da8w4_concat_linear_cpu_pass -from .int8_sdpa_fusion import _int8_sdpa_init +from .qsdpa_fusion import _qsdpa_init __all__ = [ - "_int8_sdpa_init", + "_qsdpa_init", "register_da8w4_concat_linear_cpu_pass", ] diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py deleted file mode 100644 index 5e032f01c2..0000000000 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ /dev/null @@ -1,396 +0,0 @@ -import functools -import itertools - -import torch -from torch._dynamo.utils import counters -from torch._inductor import config -from torch._inductor.lowering import lowerings as L -from torch._inductor.lowering import make_fallback -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - KeywordArg, - Match, - PatternMatcherPass, - register_lowering_pattern, -) - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_7: - # TORCH_VERSION_AT_LEAST_2_7 is needed for functions in int8 sdpa lowering - from ..int8_sdpa_lowering import register_int8_sdpa # noqa: F401 -else: - make_fallback(torch.ops.torchao.qscaled_dot_product.default) - -__all__ = [ - "_int8_sdpa_init", -] - -aten = torch.ops.aten - - -def _is_valid_int8_sdpa_pattern(): - def fn(match): - assert all(k in match.kwargs for k in ("query", "key", "value")) - query = match.kwargs["query"].meta["val"] - key = match.kwargs["key"].meta["val"] - value = match.kwargs["value"].meta["val"] - return ( - query.dtype == torch.uint8 - and key.dtype == torch.uint8 - and value.dtype == torch.uint8 - and query.device.type == "cpu" - and key.device == query.device - and value.device == query.device - ) - - return fn - - -def _register_int8_sdpa_pattern(pattern, custom_pass_dict): - @register_lowering_pattern( - pattern, extra_check=_is_valid_int8_sdpa_pattern(), pass_dict=custom_pass_dict - ) - def int8_sdpa(match: Match, *args, **kwargs): - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - scale = 1.0 / kwargs["inv_scale"] if "inv_scale" in kwargs else None - attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None - q_scale = kwargs["q_scale"] - q_zp = kwargs["q_zp"] - k_scale = kwargs["k_scale"] - k_zp = kwargs["k_zp"] - v_scale = kwargs["v_scale"] - v_zp = kwargs["v_zp"] - a_scale = kwargs["a_scale"] - a_zp = kwargs["a_zp"] - o_scale = kwargs["o_scale"] - o_zp = kwargs["o_zp"] - counters["inductor"]["int8_fuse_attention"] += 1 - counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) - - trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) - trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) - trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) - output = L[torch.ops.torchao.qscaled_dot_product.default]( - trans_query, - trans_key, - trans_value, - attn_mask, - 0.0, # dropout - False, # is_causal - scale, # scale - q_scale, - q_zp, - k_scale, - k_zp, - v_scale, - v_zp, - a_scale, - a_zp, - o_scale, - o_zp, - ) - trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) - return L[aten.clone.default]( - trans_output, memory_format=torch.contiguous_format - ) - - return int8_sdpa - - -def _get_int8_sdpa_qkv_pattern( - 
is_batch_size_1: bool, has_convert: bool, input_name: str -): - assert input_name in ["query", "key", "value"] - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - KeywordArg(input_name), - Arg(), - ) - if input_name == "key": - # do transpose - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - int8_sdpa_qkv_pattern_before_dequant, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - int8_sdpa_qkv_pattern_before_dequant, - KeywordArg(input_name[0] + "_scale"), - KeywordArg(input_name[0] + "_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_qkv_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, - CallFunction( - aten.clone.default, - int8_sdpa_qkv_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_score_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "query" - ) - int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "key" - ) - int8_sdpa_score_basic_pattern = CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_q_pattern, - int8_sdpa_k_pattern, - ), - Arg(), - ) - if is_reduced_type and not has_mask: - int8_sdpa_score_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_score_basic_pattern, - Arg(), - ) - if has_mask: - return CallFunction( - aten.add.Tensor, - CallFunction( - aten.div.Tensor, - int8_sdpa_score_basic_pattern, - KeywordArg("inv_scale"), - ), - KeywordArg("attn_mask"), - _users=2, - ) - else: - return CallFunction( - aten.mul.Tensor, - int8_sdpa_score_basic_pattern, - Arg(), - _users=2, - ) - - -def _get_int8_sdpa_exp_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_exp_basic_pattern = CallFunction( - aten.sub.Tensor, - int8_sdpa_score_pattern, - CallFunction( - aten.amax.default, - int8_sdpa_score_pattern, - Arg(), - Arg(), - ), - ) - if has_mask: - return CallFunction( - aten.exp.default, - int8_sdpa_exp_basic_pattern, - _users=2, - ) - else: - return CallFunction( - aten.exp.default, - CallFunction( - aten.div.Tensor, - int8_sdpa_exp_basic_pattern, - KeywordArg("inv_scale"), - ), - _users=2, - ) - - -def _get_int8_sdpa_attn_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_div_pattern = CallFunction( - aten.div.Tensor, - int8_sdpa_exp_pattern, - CallFunction( - aten.sum.dim_IntList, - int8_sdpa_exp_pattern, - Arg(), - Arg(), - ), - ) - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - int8_sdpa_div_pattern, - 
KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if is_reduced_type: - if has_mask: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - else: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_div_pattern, - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - return CallFunction( - aten.reshape.default, - CallFunction( - aten.expand.default, - int8_sdpa_softmax_pattern, - Arg(), - ), - Arg(), - ) - - -# Parameters to generate various patterns: -# has_mask: if SDPA has attention mask -# is_batch_size_1: if the batch size is 1 -# is_reduced_type: if autocast is enabled -# has_convert: convert type if dequant out dtype is assigned -def _get_int8_sdpa_final_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "value" - ) - int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - return CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - aten.clone.default, - CallFunction( - aten.permute.default, - CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_attn_pattern, - int8_sdpa_v_pattern, - ), - Arg(), - ), - Arg(), - ), - memory_format=Arg(), - ), - KeywordArg("o_scale"), - KeywordArg("o_zp"), - Arg(), - Arg(), - Arg(), - ) - - -def _register_int8_sdpa_lowerings(custom_pass_dict): - for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product( - [True, False], [True, False], [True, False], [True, False] - ): - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=has_mask, - is_batch_size_1=is_batch_size_1, - is_reduced_type=is_reduced_type, - has_convert=has_convert, - ), - custom_pass_dict, - ) - - -custom_pass = None -if TORCH_VERSION_AT_LEAST_2_7: - # TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass - from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files - - # define the custom pass - class _CustomPass(PatternMatcherPass, CustomGraphPass): - def __init__(self) -> None: - super().__init__() - - def __call__(self, g: torch.fx.graph.Graph): - self.apply(g) - - def uuid(self) -> bytes: - return get_hash_for_files((__file__,)) - - custom_pass = _CustomPass() - - -@functools.lru_cache(None) -def _int8_sdpa_init(): - if TORCH_VERSION_AT_LEAST_2_7: - _register_int8_sdpa_lowerings(config.post_grad_custom_pre_pass) - else: - pass diff --git a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py new file mode 100644 index 0000000000..2d3ef82efa --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py @@ -0,0 +1,472 @@ +import functools +import itertools + +import torch +from torch._dynamo.utils import counters +from torch._inductor import 
config +from torch._inductor.lowering import lowerings as L +from torch._inductor.lowering import make_fallback +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_lowering_pattern, +) + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + +if TORCH_VERSION_AT_LEAST_2_7: + # TORCH_VERSION_AT_LEAST_2_7 is needed for functions in qsdpa lowering + from ..qsdpa_lowering import register_qsdpa # noqa: F401 +else: + make_fallback(torch.ops.torchao.qscaled_dot_product.default) + +__all__ = [ + "_qsdpa_init", +] + +aten = torch.ops.aten +quantize_dtypes = [torch.uint8, torch.float8_e4m3fn] + +def _is_valid_qsdpa_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + return ( + query.dtype in quantize_dtypes + and key.dtype in quantize_dtypes + and value.dtype in quantize_dtypes + and query.device.type == "cpu" + and key.device == query.device + and value.device == query.device + ) + + return fn + + +def _register_qsdpa_pattern(pattern, custom_pass_dict): + @register_lowering_pattern( + pattern, extra_check=_is_valid_qsdpa_pattern(), pass_dict=custom_pass_dict + ) + def qsdpa(match: Match, *args, **kwargs): + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + scale = 1.0 / kwargs["inv_scale"] if "inv_scale" in kwargs else None + if scale is None: + scale = kwargs["scale"] if "scale" in kwargs else None + attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None + q_zp = 0 + k_zp = 0 + v_zp = 0 + a_zp = 0 + o_zp = 0 + if query.dtype == torch.uint8: + q_scale = kwargs["q_scale"] + q_zp = kwargs["q_zp"] + k_scale = kwargs["k_scale"] + k_zp = kwargs["k_zp"] + v_scale = kwargs["v_scale"] + v_zp = kwargs["v_zp"] + a_scale = kwargs["a_scale"] + a_zp = kwargs["a_zp"] + o_scale = kwargs["o_scale"] + o_zp = kwargs["o_zp"] + else: + assert match.kwargs["q_scale"].target == aten.full.default + q_scale = match.kwargs["q_scale"].args[1] + k_scale = match.kwargs["k_scale"].args[1] + v_scale = match.kwargs["v_scale"].args[1] + a_scale = match.kwargs["a_scale"].args[1] + o_scale = match.kwargs["o_scale"].args[1] + + counters["inductor"]["qsdpa_fuse_attention"] += 1 + counters["inductor"]["qsdpa_nodes"] += len(match.nodes) + + trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) + trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) + trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) + output = L[torch.ops.torchao.qscaled_dot_product.default]( + trans_query, + trans_key, + trans_value, + attn_mask, + 0.0, # dropout + False, # is_causal + scale, + q_scale, + q_zp, + k_scale, + k_zp, + v_scale, + v_zp, + a_scale, + a_zp, + o_scale, + o_zp, + ) + trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) + return L[aten.clone.default]( + trans_output, memory_format=torch.contiguous_format + ) + + return qsdpa + + +def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str, zp: str=None): + if qtype == torch.uint8: + assert zp is not None, "Zero point must be provided for uint8 dequantization" + return CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + input_pattern, + KeywordArg(scale), + KeywordArg(zp), + Arg(), + Arg(), + Arg(), + ) + else: + assert zp is None, "Fp8 dequantization does not support zero point" + if is_reduced_type: + return CallFunction( + 
torch.ops.torchao.dequantize_affine_float8.default, + input_pattern, + KeywordArg(scale), + Arg(), + ) + else: + return CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + input_pattern, + KeywordArg(scale), + ) + + +def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None): + if qtype == torch.uint8: + assert zp is not None, "Zero point must be provided for uint8 quantization" + return CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + input_pattern, + KeywordArg(scale), + KeywordArg(zp), + Arg(), + Arg(), + Arg(), + ) + else: + assert zp is None, "Fp8 quantization does not support zero point" + return CallFunction( + torch.ops.torchao.quantize_affine_float8.default, + input_pattern, + KeywordArg(scale), + ) + + +def _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, input_name: str +): + assert input_name in ["query", "key", "value"] + qsdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + KeywordArg(input_name), + Arg(), + ) + if input_name == "key": + # do transpose + qsdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + qsdpa_qkv_pattern_before_dequant, + Arg(), + ) + qsdpa_qkv_basic_pattern = _generate_dequant_pattern( + qsdpa_qkv_pattern_before_dequant, + qtype, + is_reduced_type, + input_name[0] + "_scale", + input_name[0] + "_zp" if qtype is torch.uint8 else None, + ) + if has_convert: + qsdpa_qkv_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + qsdpa_qkv_basic_pattern = CallFunction( + aten.expand.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + qsdpa_qkv_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_qsdpa_score_pattern( + qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool +): + qsdpa_q_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "query" + ) + qsdpa_k_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "key" + ) + qsdpa_score_basic_pattern = CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + qsdpa_q_pattern, + qsdpa_k_pattern, + ), + Arg(), + ) + if is_reduced_type and not has_mask: + qsdpa_score_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_score_basic_pattern, + Arg(), + ) + if not has_mask: + return CallFunction( + aten.mul.Tensor, + qsdpa_score_basic_pattern, + Arg(), + _users=2, + ) + elif is_inv_scale: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.div.Tensor, + qsdpa_score_basic_pattern, + KeywordArg("inv_scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + else: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.mul.Tensor, + qsdpa_score_basic_pattern, + KeywordArg("scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + + +def _get_qsdpa_exp_pattern( + qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool +): + qsdpa_score_pattern = _get_qsdpa_score_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + qsdpa_exp_basic_pattern = CallFunction( + aten.sub.Tensor, + 
qsdpa_score_pattern, + CallFunction( + aten.amax.default, + qsdpa_score_pattern, + Arg(), + Arg(), + ), + ) + if has_mask: + return CallFunction( + aten.exp.default, + qsdpa_exp_basic_pattern, + _users=2, + ) + elif is_inv_scale: + return CallFunction( + aten.exp.default, + CallFunction( + aten.div.Tensor, + qsdpa_exp_basic_pattern, + KeywordArg("inv_scale"), + ), + _users=2, + ) + else: + return CallFunction( + aten.exp.default, + CallFunction( + aten.mul.Tensor, + qsdpa_exp_basic_pattern, + KeywordArg("scale"), + ), + _users=2, + ) + + +def _get_qsdpa_attn_pattern( + qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool +): + qsdpa_exp_pattern = _get_qsdpa_exp_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + qsdpa_div_pattern = CallFunction( + aten.div.Tensor, + qsdpa_exp_pattern, + CallFunction( + aten.sum.dim_IntList, + qsdpa_exp_pattern, + Arg(), + Arg(), + ), + ) + qsdpa_softmax_pattern = _generate_dequant_pattern( + _generate_quant_pattern( + qsdpa_div_pattern, + qtype, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ), + qtype, + is_reduced_type, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ) + if is_reduced_type: + if has_mask: + qsdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_softmax_pattern, + Arg(), + ) + else: + qsdpa_softmax_pattern = _generate_dequant_pattern( + _generate_quant_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_div_pattern, + Arg(), + ), + qtype, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ), + qtype, + is_reduced_type, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ) + if has_convert: + qsdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_softmax_pattern, + Arg(), + ) + return CallFunction( + aten.reshape.default, + CallFunction( + aten.expand.default, + qsdpa_softmax_pattern, + Arg(), + ), + Arg(), + ) + + +# Parameters to generate various patterns: +# qdtype: quantized dtypes are uint8, float8_e4m3fn for now +# has_mask: if SDPA has attention mask +# is_batch_size_1: if the batch size is 1 +# is_reduced_type: if autocast is enabled +# has_convert: convert type if dequant out dtype is assigned +# is_inv_scale: if the scale in SDPA is inversed, in which case it is multiplied instead of divided +def _get_qsdpa_final_pattern( + qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool +): + qsdpa_v_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "value" + ) + qsdpa_attn_pattern = _get_qsdpa_attn_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + return _generate_quant_pattern( + CallFunction( + aten.clone.default, + CallFunction( + aten.permute.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + qsdpa_attn_pattern, + qsdpa_v_pattern, + ), + Arg(), + ), + Arg(), + ), + memory_format=Arg(), + ), + qtype, + "o_scale", + "o_zp" if qtype is torch.uint8 else None, + ) + + +def _register_qsdpa_lowerings(custom_pass_dict): + for qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale in itertools.product( + quantize_dtypes, [True, False], [True, False], [True, False], [True, False], [True, False] + ): + _register_qsdpa_pattern( + _get_qsdpa_final_pattern( + qtype=qtype, + has_mask=has_mask, + 
is_batch_size_1=is_batch_size_1, + is_reduced_type=is_reduced_type, + has_convert=has_convert, + is_inv_scale=is_inv_scale, + ), + custom_pass_dict, + ) + + +custom_pass = None +if TORCH_VERSION_AT_LEAST_2_7: + # TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass + from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files + + # define the custom pass + class _CustomPass(PatternMatcherPass, CustomGraphPass): + def __init__(self) -> None: + super().__init__() + + def __call__(self, g: torch.fx.graph.Graph): + self.apply(g) + + def uuid(self) -> bytes: + return get_hash_for_files((__file__,)) + + custom_pass = _CustomPass() + + +@functools.lru_cache(None) +def _qsdpa_init(): + if TORCH_VERSION_AT_LEAST_2_7: + _register_qsdpa_lowerings(config.post_grad_custom_pre_pass) + else: + pass diff --git a/torchao/prototype/inductor/int8_sdpa_lowering.py b/torchao/prototype/inductor/qsdpa_lowering.py similarity index 67% rename from torchao/prototype/inductor/int8_sdpa_lowering.py rename to torchao/prototype/inductor/qsdpa_lowering.py index be989adb33..725ca19dea 100644 --- a/torchao/prototype/inductor/int8_sdpa_lowering.py +++ b/torchao/prototype/inductor/qsdpa_lowering.py @@ -1,70 +1,31 @@ -from collections.abc import Sequence from typing import Optional import sympy import torch from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order +from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize from torch._inductor.lowering import register_lowering from torch._inductor.select_algorithm import ( ExternKernelChoice, autotune_select_algorithm, - realize_inputs, ) -from torch.utils._pytree import tree_map from .codegen.cpp_int8_sdpa_template import CppInt8SdpaTemplate - -# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd -def maybe_realize(args): - """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" - return tree_map( - lambda x: ( - realize_inputs(x) - if x is not None and not isinstance(x, sympy.Symbol) - else x - ), - args, - ) - - -# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd -def construct_strides( - sizes: Sequence[int], - fill_order: Sequence[int], -) -> Sequence[int]: - """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" - # Initialize strides - assert len(sizes) == len(fill_order), ( - "Length of sizes must match the length of the fill order" - ) - strides = [0] * len(sizes) - - # Start with stride 1 for the innermost dimension - current_stride = 1 - - # Iterate through the fill order populating strides - for dim in fill_order: - strides[dim] = current_stride - current_stride *= sizes[dim] - - return strides - - -op_int8_sdpa = ExternKernelChoice( +op_qsdpa = ExternKernelChoice( torch.ops.torchao.qscaled_dot_product.default, "torchao::qscaled_dot_product", has_out_variant=False, use_fallback_kernel=True, op_overload=torch.ops.torchao.qscaled_dot_product.default, ) +quantize_dtypes = [torch.uint8, torch.float8_e4m3fn] - -def register_int8_sdpa(): +def register_qsdpa(): @register_lowering( torch.ops.torchao.qscaled_dot_product.default, type_promotion_kind=None ) - def int8_sdpa( + def qsdpa( query: TensorBox, key: TensorBox, value: TensorBox, @@ -100,12 +61,12 @@ def int8_sdpa( ) if ( - query.get_dtype() is not torch.uint8 - or key.get_dtype() is not torch.uint8 - or value.get_dtype() is not torch.uint8 + query.get_dtype() not in 
quantize_dtypes + or key.get_dtype() not in quantize_dtypes + or value.get_dtype() not in quantize_dtypes ): raise NotImplementedError( - "Only `torch.uint8` is supported in Int8 SDPA template for CPU device. " + "Only `torch.uint8` or `torch.float8_e4m3fn` is supported in Quantized SDPA template for CPU device. " f"Found input tensors are `{query.get_dtype()}`,`{key.get_dtype()}`,`{value.get_dtype()}`." ) @@ -124,8 +85,8 @@ def int8_sdpa( if attn_mask is not None: input_nodes.append(attn_mask) - # use template if machine has amx - if torch._C._cpu._is_amx_tile_supported(): + # use template if machine has amx, only support uint8 for now + if torch._C._cpu._is_amx_tile_supported() and query.get_dtype() is torch.uint8: CppInt8SdpaTemplate.add_choices( choices=choices, input_nodes=input_nodes, @@ -145,7 +106,7 @@ def int8_sdpa( if len(choices) == 0: choices.append( - op_int8_sdpa.bind( + op_qsdpa.bind( input_nodes=input_nodes, layout=layout, scale=scale, @@ -169,11 +130,11 @@ def int8_sdpa( ] return autotune_select_algorithm( - "int8_sdpa", + "qsdpa", choices, inputs_for_autotuning, layout, ) -register_int8_sdpa() +register_qsdpa() From b87509e8febf7740d73ab87133922731304865a1 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 5 Aug 2025 07:16:51 +0000 Subject: [PATCH 2/4] fix format --- .../inductor/fx_passes/qsdpa_fusion.py | 75 ++++++++++++++----- torchao/prototype/inductor/qsdpa_lowering.py | 1 + 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py index 2d3ef82efa..91add22bb2 100644 --- a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py @@ -30,6 +30,7 @@ aten = torch.ops.aten quantize_dtypes = [torch.uint8, torch.float8_e4m3fn] + def _is_valid_qsdpa_pattern(): def fn(match): assert all(k in match.kwargs for k in ("query", "key", "value")) @@ -117,7 +118,9 @@ def qsdpa(match: Match, *args, **kwargs): return qsdpa -def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str, zp: str=None): +def _generate_dequant_pattern( + input_pattern, qtype, is_reduced_type, scale: str, zp: str = None +): if qtype == torch.uint8: assert zp is not None, "Zero point must be provided for uint8 dequantization" return CallFunction( @@ -146,7 +149,7 @@ def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str, ) -def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None): +def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None): if qtype == torch.uint8: assert zp is not None, "Zero point must be provided for uint8 quantization" return CallFunction( @@ -168,7 +171,11 @@ def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None): def _get_qsdpa_qkv_pattern( - qtype, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, input_name: str + qtype, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + input_name: str, ): assert input_name in ["query", "key", "value"] qsdpa_qkv_pattern_before_dequant = CallFunction( @@ -221,7 +228,12 @@ def _get_qsdpa_qkv_pattern( def _get_qsdpa_score_pattern( - qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, ): qsdpa_q_pattern = _get_qsdpa_qkv_pattern( qtype, is_batch_size_1, is_reduced_type, 
has_convert, "query" @@ -276,7 +288,12 @@ def _get_qsdpa_score_pattern( def _get_qsdpa_exp_pattern( - qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, ): qsdpa_score_pattern = _get_qsdpa_score_pattern( qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale @@ -298,15 +315,15 @@ def _get_qsdpa_exp_pattern( _users=2, ) elif is_inv_scale: - return CallFunction( - aten.exp.default, - CallFunction( - aten.div.Tensor, - qsdpa_exp_basic_pattern, - KeywordArg("inv_scale"), - ), - _users=2, - ) + return CallFunction( + aten.exp.default, + CallFunction( + aten.div.Tensor, + qsdpa_exp_basic_pattern, + KeywordArg("inv_scale"), + ), + _users=2, + ) else: return CallFunction( aten.exp.default, @@ -320,7 +337,12 @@ def _get_qsdpa_exp_pattern( def _get_qsdpa_attn_pattern( - qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, ): qsdpa_exp_pattern = _get_qsdpa_exp_pattern( qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale @@ -396,7 +418,12 @@ def _get_qsdpa_attn_pattern( # has_convert: convert type if dequant out dtype is assigned # is_inv_scale: if the scale in SDPA is inversed, in which case it is multiplied instead of divided def _get_qsdpa_final_pattern( - qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, ): qsdpa_v_pattern = _get_qsdpa_qkv_pattern( qtype, is_batch_size_1, is_reduced_type, has_convert, "value" @@ -429,8 +456,20 @@ def _get_qsdpa_final_pattern( def _register_qsdpa_lowerings(custom_pass_dict): - for qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale in itertools.product( - quantize_dtypes, [True, False], [True, False], [True, False], [True, False], [True, False] + for ( + qtype, + has_mask, + is_batch_size_1, + is_reduced_type, + has_convert, + is_inv_scale, + ) in itertools.product( + quantize_dtypes, + [True, False], + [True, False], + [True, False], + [True, False], + [True, False], ): _register_qsdpa_pattern( _get_qsdpa_final_pattern( diff --git a/torchao/prototype/inductor/qsdpa_lowering.py b/torchao/prototype/inductor/qsdpa_lowering.py index 725ca19dea..079d2d98b1 100644 --- a/torchao/prototype/inductor/qsdpa_lowering.py +++ b/torchao/prototype/inductor/qsdpa_lowering.py @@ -21,6 +21,7 @@ ) quantize_dtypes = [torch.uint8, torch.float8_e4m3fn] + def register_qsdpa(): @register_lowering( torch.ops.torchao.qscaled_dot_product.default, type_promotion_kind=None From 4fe4c3bd1174cd0327d9730cd8ca12f7bc807c79 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 5 Aug 2025 07:21:36 +0000 Subject: [PATCH 3/4] fix import --- torchao/prototype/inductor/qsdpa_lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/inductor/qsdpa_lowering.py b/torchao/prototype/inductor/qsdpa_lowering.py index 079d2d98b1..22b30cdb91 100644 --- a/torchao/prototype/inductor/qsdpa_lowering.py +++ b/torchao/prototype/inductor/qsdpa_lowering.py @@ -3,7 +3,7 @@ import sympy import torch from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order -from 
torch._inductor.kernel.flex_attention import construct_strides, maybe_realize
+from torch._inductor.kernel.flex.common import construct_strides, maybe_realize
 from torch._inductor.lowering import register_lowering
 from torch._inductor.select_algorithm import (
     ExternKernelChoice,
     autotune_select_algorithm,
 )

From 123fd8c5da875d23446d8dd001aea2d916ebbd57 Mon Sep 17 00:00:00 2001
From: Valentine233 
Date: Tue, 5 Aug 2025 07:30:50 +0000
Subject: [PATCH 4/4] fix import

---
 torchao/prototype/inductor/qsdpa_lowering.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torchao/prototype/inductor/qsdpa_lowering.py b/torchao/prototype/inductor/qsdpa_lowering.py
index 22b30cdb91..da6c1af0b4 100644
--- a/torchao/prototype/inductor/qsdpa_lowering.py
+++ b/torchao/prototype/inductor/qsdpa_lowering.py
@@ -3,7 +3,13 @@
 import sympy
 import torch
 from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order
-from torch._inductor.kernel.flex.common import construct_strides, maybe_realize
+
+try:
+    # prefer the new module path introduced by the flex-attention kernel refactor
+    from torch._inductor.kernel.flex.common import construct_strides, maybe_realize
+except ImportError:
+    # fall back to the old path for older torch versions
+    from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize
 from torch._inductor.lowering import register_lowering
 from torch._inductor.select_algorithm import (
     ExternKernelChoice,
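
In the qsdpa lowering changed above, the C++ SDPA template is only offered as a choice when AMX tiles are available and the inputs are uint8, so float8_e4m3fn inputs always take the extern torchao::qscaled_dot_product kernel for now. Below is a minimal sketch of just that gating, assuming a PyTorch build that exposes torch._C._cpu._is_amx_tile_supported (used in the patch itself); the helper name is hypothetical, and the real lowering collects ChoiceCaller objects and autotunes rather than returning strings.

import torch


def choose_qsdpa_backend(qdtype: torch.dtype) -> str:
    # Mirror of the dtype check in the qsdpa lowering.
    if qdtype not in (torch.uint8, torch.float8_e4m3fn):
        raise NotImplementedError(f"unsupported quantized dtype: {qdtype}")
    # The C++ AMX template is only wired up for uint8 so far.
    if torch._C._cpu._is_amx_tile_supported() and qdtype is torch.uint8:
        return "CppInt8SdpaTemplate (C++ AMX template)"
    # fp8 inputs, and uint8 on machines without AMX, fall back to the extern kernel.
    return "torchao::qscaled_dot_product (extern kernel)"


print(choose_qsdpa_backend(torch.uint8))
print(choose_qsdpa_backend(torch.float8_e4m3fn))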
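
The pattern-generation parameters described earlier in the series (quantized dtype plus the five boolean knobs) are combined exhaustively by _register_qsdpa_lowerings, so the pass registers 2 x 2^5 = 64 pattern variants in total. A minimal stand-alone sketch of that enumeration follows, only to make the combinatorics concrete; the flags list is illustrative.

import itertools

import torch

# The two quantized dtypes covered by the fused qsdpa patterns.
quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]

# The five boolean knobs used by the pattern generators; each one doubles
# the number of generated variants.
flags = ["has_mask", "is_batch_size_1", "is_reduced_type", "has_convert", "is_inv_scale"]

variants = list(itertools.product(quantize_dtypes, *([[True, False]] * len(flags))))
print(len(variants))  # 2 dtypes * 2**5 flag combinations = 64 registered patterns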