
Commit d607951

Add tensor descriptor implementation for Flex Decoding (#4961)
pytorch/pytorch@fc69c2b removes the block pointer implementation. To get good performance on Intel GPUs, a structural pointer representation (e.g., a tensor descriptor) is required, so this PR adds a tensor descriptor implementation for Flex Decoding under `USE_TMA`.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent b335d39 commit d607951
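
Background (not part of the commit, added for context): the mechanism the patch relies on is Triton's tensor-descriptor API, in which a 2D tensor is described once, by base pointer, global shape, strides, and block shape, and tiles are then loaded by block offset instead of by per-element pointer arithmetic. This is the "structural pointer representation" the commit message refers to, and it is the information that lets a backend emit efficient 2D block loads. A minimal sketch follows, assuming a Triton build that exposes tl.make_tensor_descriptor and a backend that supports it; the kernel, tensor names, and block sizes are illustrative only, and device-side descriptor creation may additionally require triton.set_allocator() on some backends.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_tiles_kernel(src, dst, M, N, stride_m,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Describe the whole 2D source once: base pointer, global shape, strides,
    # and the tile shape fetched per load (mirrors desc_k / desc_v in the patch).
    desc_src = tl.make_tensor_descriptor(
        base=src,
        shape=[M, N],
        strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    pid = tl.program_id(0)
    # Fetch one BLOCK_M x BLOCK_N tile by block offset; boundary handling comes
    # from the descriptor's shape rather than from an explicit load mask.
    tile = desc_src.load([pid * BLOCK_M, 0])
    # Write the tile back with ordinary pointer arithmetic to keep the sketch short.
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(dst + offs_m[:, None] * stride_m + offs_n[None, :], tile, mask=mask)

On the host this launches like any other Triton kernel, e.g. copy_tiles_kernel[(triton.cdiv(M, 64),)](src, dst, M, N, src.stride(0), BLOCK_M=64, BLOCK_N=64). The patched flex_decode template builds exactly this kind of descriptor for K and V and hands it to forward_inner in place of the removed block pointers.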

File tree

2 files changed  +80 -0 lines changed


scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions

@@ -36,3 +36,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 
 # put your patch applies here
 apply_patch ./patch/flex_attn_143553.patch
+apply_patch ./patch/flex_decoding_tensor_desc.patch
scripts/patch/flex_decoding_tensor_desc.patch

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
index 679caa9f09..6192275691 100644
--- a/torch/_inductor/kernel/flex/flex_decoding.py
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
@@ -326,6 +326,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
         # Set default to False
         cur_kernel_options.setdefault("USE_TMA", False)
 
+        if torch.xpu.is_available():
+            cur_kernel_options["USE_TMA"] = True
+
         # Add ROCm-specific parameters if they exist in the config
         for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
             if hasattr(conf, attrib):
diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
index f4e894d9b7..3fb3b2c5bd 100644
--- a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
+++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
@@ -128,11 +128,28 @@
     # last valid block according to sparse mask
     block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
 
+    desc_k = None
+    desc_v = None
+    {%- if USE_TMA %}
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN, QK_HEAD_DIM],
+        strides=[stride_kn, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN, V_HEAD_DIM],
+        strides=[stride_vn, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}
     offs_n = tl.arange(0, BLOCK_N) + off_n
 
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulatd values
         acc, l_i, m_i,
         #offsets
@@ -163,11 +180,29 @@
     # last valid block according to sparse mask
     block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
 
+    desc_k = None
+    desc_v = None
+    {%- if USE_TMA %}
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN, QK_HEAD_DIM],
+        strides=[stride_kn, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN, V_HEAD_DIM],
+        strides=[stride_vn, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}
+
     offs_n = tl.arange(0, BLOCK_N) + off_n
 
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulatd values
         acc, l_i, m_i,
         #offsets
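
For illustration only (not part of the commit): a decoding-shaped flex_attention call is the kind of workload that reaches create_flex_decoding_kernel, where the patched code switches USE_TMA on when torch.xpu.is_available(). The sketch below assumes a PyTorch XPU build with this patch applied; shapes and dtypes are arbitrary.

import torch
from torch.nn.attention.flex_attention import flex_attention

device = "xpu" if torch.xpu.is_available() else "cuda"
B, H, KV_LEN, D = 1, 16, 4096, 64

# Q_LEN == 1 is the decoding case handled by create_flex_decoding_kernel.
q = torch.randn(B, H, 1, D, device=device, dtype=torch.float16)
k = torch.randn(B, H, KV_LEN, D, device=device, dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device=device, dtype=torch.float16)

# flex_attention must be compiled for the Triton decode template to be used;
# on XPU the patched kernel options force USE_TMA=True, so the template builds
# desc_k / desc_v with tl.make_tensor_descriptor and passes them to forward_inner.
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v)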
