Skip to content

Commit 836dc8d

Browse files
authored
[Wave] Use reciprocal to compute softcap logits (#674)
This PR speeds up the softcap logit computation by multiplying by a precomputed reciprocal of the logit cap instead of dividing by it.
Signed-off-by: Harsh Menon <[email protected]>
1 parent 8141d52 commit 836dc8d

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

iree/turbine/kernel/wave/templates/extend_attention.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,8 @@ def first_loop(
256256
x_j = tkw.permute(inner_acc, target_shape=[H, N_Q, N_KV])
257257
x_j = x_j * layer_scale_reg
258258
if logit_cap > 0:
259-
x_j = logit_cap_reg * tkw.tanh(x_j / logit_cap_reg)
259+
logit_cap_reg_inv = tkw.reciprocal(logit_cap_reg)
260+
x_j = logit_cap_reg * tkw.tanh(x_j * logit_cap_reg_inv)
260261
n_kv_index = tkw.self_index(N_KV, tkl.i32)
261262
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
262263
mask = tkw.broadcast(mask, target_shape=[N_Q, N_KV])
@@ -308,7 +309,8 @@ def second_loop(
308309
x_j = tkw.permute(inner_acc, target_shape=[H, N_Q, N_KV])
309310
x_j = x_j * layer_scale_reg
310311
if logit_cap > 0:
311-
x_j = logit_cap_reg * tkw.tanh(x_j / logit_cap_reg)
312+
logit_cap_reg_inv = tkw.reciprocal(logit_cap_reg)
313+
x_j = logit_cap_reg * tkw.tanh(x_j * logit_cap_reg_inv)
312314
n_kv_index = tkw.self_index(N_KV, tkl.i32)
313315
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
314316
mask = tkw.broadcast(mask, target_shape=[N_Q, N_KV])

lit_tests/kernel/wave/attention/extend_attention.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def test_causal_extend_attention():
184184
# CHECK-COUNT-8: amdgpu.mfma
185185

186186
# softcap/logitcap modifier:
187-
# CHECK-COUNT-2: arith.divf
187+
# CHECK-COUNT-4: arith.mulf
188188
# CHECK-COUNT-2: math.tanh
189189
# CHECK-COUNT-2: arith.mulf
190190

@@ -216,7 +216,7 @@ def test_causal_extend_attention():
216216
# CHECK-COUNT-8: amdgpu.mfma
217217

218218
# softcap/logitcap modifier:
219-
# CHECK-COUNT-2: arith.divf
219+
# CHECK-COUNT-4: arith.mulf
220220
# CHECK-COUNT-2: math.tanh
221221
# CHECK-COUNT-2: arith.mulf
222222

@@ -301,7 +301,7 @@ def test_causal_extend_attention_32x32x8():
301301
# CHECK-COUNT-8: amdgpu.mfma
302302

303303
# softcap/logitcap modifier:
304-
# CHECK-COUNT-1: arith.divf
304+
# CHECK-COUNT-2: arith.mulf
305305
# CHECK-COUNT-1: math.tanh
306306
# CHECK-COUNT-1: arith.mulf
307307

@@ -325,7 +325,7 @@ def test_causal_extend_attention_32x32x8():
325325
# CHECK-COUNT-8: amdgpu.mfma
326326

327327
# softcap/logitcap modifier:
328-
# CHECK-COUNT-1: arith.divf
328+
# CHECK-COUNT-2: arith.mulf
329329
# CHECK-COUNT-1: math.tanh
330330
# CHECK-COUNT-1: arith.mulf
331331

0 commit comments

Comments
 (0)