Commit c5d1d59

[cpu][fp8] support fp8 sdpa for cpu
1 parent bc2c83e commit c5d1d59

File tree

8 files changed: +1333 -559 lines changed


test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 4 additions & 4 deletions
@@ -11,8 +11,8 @@
 from torch.testing._internal.inductor_utils import HAS_CPU
 
 import torchao
-from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import (
-    _int8_sdpa_init,
+from torchao.prototype.inductor.fx_passes.qsdpa_fusion import (
+    _qsdpa_init,
     custom_pass,
 )
 from torchao.utils import torch_version_at_least
@@ -120,7 +120,7 @@ def _check_common(
         )
         source_code = "\n".join(source_code)
         if has_fuse_pattern:
-            self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
+            self.assertGreaterEqual(counters["inductor"]["qsdpa_fuse_attention"], 1)
         if contains:
             self.assertTrue(
                 any(
@@ -192,7 +192,7 @@ def _test_sdpa_int8_rewriter(self):
             ),
             config.patch(post_grad_custom_pre_pass=custom_pass),
         ):
-            _int8_sdpa_init()
+            _qsdpa_init()
            quantizer = X86InductorQuantizer()
            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
            quantizer.set_function_type_qconfig(
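
For orientation, a minimal usage sketch (not part of this commit) of how the renamed pass could be enabled around torch.compile, following the pattern used in _test_sdpa_int8_rewriter above; `model` and `example_inputs` are placeholders:

import torch
from torch._inductor import config
from torchao.prototype.inductor.fx_passes.qsdpa_fusion import _qsdpa_init, custom_pass

# Route Inductor's post-grad graph through the custom pass and register the
# quantized-SDPA fusion patterns, as the test does inside its context managers.
with config.patch(post_grad_custom_pre_pass=custom_pass):
    _qsdpa_init()
    compiled = torch.compile(model)    # `model`: a quantized model containing SDPA (placeholder)
    out = compiled(*example_inputs)    # `example_inputs`: matching inputs (placeholder)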

test/test_ops.py

Lines changed: 140 additions & 63 deletions
@@ -154,51 +154,101 @@ def _scaled_dot_product_int8_op_ref(
         out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
         return out.to(torch.uint8)
 
+    def _scaled_dot_product_fp8_op_ref(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        dropout_p=0,
+        is_causal=False,
+        q_scale=1.0,
+        k_scale=1.0,
+        v_scale=1.0,
+        a_scale=1.0,
+        o_scale=1.0,
+    ):
+        q = q.to(torch.float) * q_scale
+        k = k.to(torch.float) * k_scale
+        v = v.to(torch.float) * v_scale
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn * scale_factor
+        if attn_mask is not None:
+            attn = attn + attn_mask.to(torch.float)
+        attn_max = attn.max(dim=-1, keepdim=True).values
+        attn = attn - attn_max
+        attn = torch.exp(attn)
+        attn_sum = torch.sum(attn, dim=-1, keepdim=True)
+        attn = attn / attn_sum
+        attn = torch.clamp(attn / a_scale, min=-448, max=448)
+        attn = attn.to(torch.float8_e4m3fn).to(torch.float)
+        attn = attn * a_scale
+        out = attn @ v
+        out = torch.clamp(out / o_scale, min=-448, max=448)
+        return out.to(torch.float8_e4m3fn)
+
     @pytest.mark.skipif(
         not torch_version_at_least("2.7.0"),
-        reason="int8 sdpa requires torch 2.7 or later",
+        reason="quantized sdpa requires torch 2.7 or later",
     )
     @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
     @pytest.mark.skipif(
         "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
         reason="cpp kernels not built",
     )
+    @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
     @parametrize("batch_size", [56, 120])
     @parametrize("n_head", [2, 16])
     @parametrize("q_seq_len", [18, 89])
     @parametrize("kv_seq_len", [100, 253])
     @parametrize("head_dim", [32, 64])
     @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16])
-    def test_scaled_dot_product_int8_op(
-        self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype
+    def test_quantized_scaled_dot_product_op(
+        self,
+        input_dtype,
+        batch_size,
+        n_head,
+        q_seq_len,
+        kv_seq_len,
+        head_dim,
+        mask_dtype,
     ):
         torch.manual_seed(1234)
         device = "cpu"
-        q_scale = float(1.7907238006591797)
-        q_zp = int(127)
-        k_scale = float(1.8039721250534058)
-        k_zp = int(125)
-        v_scale = float(1.839004635810852)
-        v_zp = int(127)
-        a_scale = float(0.003919653594493866)
-        a_zp = int(120)
-        o_scale = float(1.8191684484481812)
-        o_zp = int(128)
+        if input_dtype == torch.uint8:
+            q_scale = float(1.7907238006591797)
+            k_scale = float(1.8039721250534058)
+            v_scale = float(1.839004635810852)
+            a_scale = float(0.003919653594493866)
+            o_scale = float(1.8191684484481812)
+            q_zp = int(127)
+            k_zp = int(125)
+            v_zp = int(127)
+            a_zp = int(120)
+            o_zp = int(128)
+            atol, rtol = 1.0, 5e-6
+        else:
+            q_scale = float(5.96875)
+            k_scale = float(5.78125)
+            v_scale = float(0.98046875)
+            a_scale = float(4.84375)
+            o_scale = float(3.171875)
+            atol, rtol = 0.125, 5e-6
         q_shape = [batch_size, q_seq_len, n_head, head_dim]
         kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
         mask_shape = [batch_size, 1, 1, kv_seq_len]
-        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
-        k = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        v = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        q = q.to(torch.uint8)
-        k = k.to(torch.uint8)
-        v = v.to(torch.uint8)
+        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2)
+        k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
+        v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
+        if input_dtype == torch.uint8:
+            q *= 100
+            k *= 100
+            v *= 100
+        q = q.to(input_dtype)
+        k = k.to(input_dtype)
+        v = v.to(input_dtype)
         attn_mask = (
             torch.randn(mask_shape, dtype=mask_dtype, device=device)
             if mask_dtype is not None
@@ -211,44 +261,71 @@ def test_scaled_dot_product_int8_op(
             attn_mask.clone() if mask_dtype is not None else None,
         )
 
-        math_ref = self._scaled_dot_product_int8_op_ref(
-            q2,
-            k2,
-            v2,
-            attn_mask=attn_mask,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-        actual = torch.ops.torchao.qscaled_dot_product(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask_2,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-
-        self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)
+        if input_dtype == torch.uint8:
+            math_ref = self._scaled_dot_product_int8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+        else:
+            math_ref = self._scaled_dot_product_fp8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+        self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)
 
 
 instantiate_parametrized_tests(TestOps)
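
For reference, a minimal standalone sketch (not part of the diff) of exercising the new FP8 path of torchao::qscaled_dot_product directly. Shapes are taken from one parametrization of the test above and the scale values from its FP8 branch; both are illustrative only.

import torch
import torchao  # registers the CPU torchao::qscaled_dot_product kernels

batch_size, n_head, q_seq_len, kv_seq_len, head_dim = 56, 2, 18, 100, 64
# Build q/k/v as [batch, seq, heads, dim], transpose to [batch, heads, seq, dim],
# then cast to float8_e4m3fn, mirroring the test setup.
q = torch.randn(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2).to(torch.float8_e4m3fn)
k = torch.randn(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2).to(torch.float8_e4m3fn)
v = torch.randn(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2).to(torch.float8_e4m3fn)

out = torch.ops.torchao.qscaled_dot_product(
    q,
    k,
    v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    q_scale=5.96875,      # illustrative scales, copied from the test's FP8 branch
    k_scale=5.78125,
    v_scale=0.98046875,
    a_scale=4.84375,
    o_scale=3.171875,
)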
