Skip to content

Commit 5271aa4

Browse files
authored
[tensor_desc]: Allow make_tensor_descriptor with non unit stride in innermost dimension (#4122)
Upstream code enforces the stride of the `make_tensor_descriptor` operation to be 1. We relax this constraint for XPU. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent fd72cd4 commit 5271aa4

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# pylint: disable=unused-argument
1111
@triton.jit
1212
def _attn_fwd_inner(acc, l_i, m_i, q, #
13-
K_block_ptr, V_desc, #
13+
K_desc, V_desc, #
1414
start_m, qk_scale, #
1515
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #
1616
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
@@ -24,13 +24,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
2424
# causal = False
2525
else:
2626
lo, hi = 0, N_CTX
27-
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
27+
off_k = lo
2828
off_v = lo
2929
# loop over k, v and update accumulator
3030
for start_n in range(lo, hi, BLOCK_N):
3131
start_n = tl.multiple_of(start_n, BLOCK_N)
3232
# -- compute qk ----
33-
k = tl.load(K_block_ptr)
33+
k = K_desc.load([0, off_k])
3434
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
3535
qk += tl.dot(q, k)
3636
if STAGE == 2:
@@ -54,7 +54,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
5454
# update m_i and l_i
5555
m_i = m_ij
5656
off_v += BLOCK_N
57-
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
57+
off_k += BLOCK_N
5858
return acc, l_i, m_i
5959

6060

@@ -90,9 +90,8 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
9090
block_shape=(BLOCK_M, BLOCK_DMODEL))
9191
V_desc = tl.make_tensor_descriptor(base=V + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn),
9292
block_shape=(BLOCK_N, BLOCK_DMODEL))
93-
#FIXME: change to a tensor descriptor.
94-
K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn),
95-
offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1))
93+
K_desc = tl.make_tensor_descriptor(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn),
94+
block_shape=(BLOCK_DMODEL, BLOCK_N))
9695
O_desc = tl.make_tensor_descriptor(base=Out + qvk_offset, shape=(N_CTX, BLOCK_DMODEL),
9796
strides=(stride_om, stride_on), block_shape=(BLOCK_M, BLOCK_DMODEL))
9897
# initialize offsets
@@ -111,14 +110,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
111110
# For causal = True, STAGE = 3, the kernel gets 1 as its STAGE
112111
# For causal = False, STAGE = 1, the kernel gets 3 as its STAGE
113112
if STAGE & 1:
114-
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_desc, #
113+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc, #
115114
start_m, qk_scale, #
116115
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
117116
4 - STAGE, offs_m, offs_n, N_CTX #
118117
)
119118
# stage 2: on-band
120119
if STAGE & 2:
121-
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_desc, #
120+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc, #
122121
start_m, qk_scale, #
123122
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
124123
2, offs_m, offs_n, N_CTX #

python/triton/language/semantic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from .._C.libtriton import ir
88
from . import core as tl
99

10+
import triton
11+
1012
T = TypeVar('T')
1113

1214

@@ -1953,7 +1955,8 @@ def make_tensor_descriptor(
19531955
)
19541956

19551957
strides[-1] = tl._constexpr_to_value(strides[-1])
1956-
if strides[-1] != 1:
1958+
backend = triton.runtime.driver.active.get_current_target().backend
1959+
if backend != "xpu" and strides[-1] != 1:
19571960
raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
19581961

19591962
shape = [to_tensor(x, builder) for x in shape]

0 commit comments

Comments
 (0)