diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
index 679caa9f09..6192275691 100644
--- a/torch/_inductor/kernel/flex/flex_decoding.py
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
@@ -326,6 +326,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
         # Set default to False
         cur_kernel_options.setdefault("USE_TMA", False)

+        if torch.xpu.is_available():
+            cur_kernel_options["USE_TMA"] = True
+
         # Add ROCm-specific parameters if they exist in the config
         for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
             if hasattr(conf, attrib):
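The hunk above turns `USE_TMA` on unconditionally whenever an XPU device is present, so any compiled flex-attention call that Inductor lowers through `create_flex_decoding_kernel` on such a build takes the tensor-descriptor load path added to the template below. A decode-shaped call (a single query token against a long KV cache) is the simplest way to reach this kernel. The sketch below is illustrative and not part of this commit; the shapes, dtype, and availability of an XPU build are assumptions:

import torch
from torch.nn.attention.flex_attention import flex_attention

# Decode shape: Q_LEN == 1 against a long KV cache, which Inductor
# typically lowers to the flex_decoding kernel patched above.
q = torch.randn(1, 8, 1, 64, device="xpu", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 64, device="xpu", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="xpu", dtype=torch.float16)

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v)  # default score_mod / block_mask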
diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
index f4e894d9b7..3fb3b2c5bd 100644
--- a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
+++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
@@ -128,11 +128,28 @@
         # last valid block according to sparse mask
         block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

+        desc_k = None
+        desc_v = None
+        {%- if USE_TMA %}
+        desc_k = tl.make_tensor_descriptor(
+            base=K,
+            shape=[KV_LEN, QK_HEAD_DIM],
+            strides=[stride_kn, 1],
+            block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+        )
+
+        desc_v = tl.make_tensor_descriptor(
+            base=V,
+            shape=[KV_LEN, V_HEAD_DIM],
+            strides=[stride_vn, 1],
+            block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+        )
+        {%- endif %}
         offs_n = tl.arange(0, BLOCK_N) + off_n

         acc, l_i, m_i = forward_inner(
             {{gen_argdefs()}},
-            q, K, V, None, None, Q_LEN, KV_LEN,
+            q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
             # accumulatd values
             acc, l_i, m_i,
             #offsets
@@ -163,11 +180,29 @@
         # last valid block according to sparse mask
         block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

+        desc_k = None
+        desc_v = None
+        {%- if USE_TMA %}
+        desc_k = tl.make_tensor_descriptor(
+            base=K,
+            shape=[KV_LEN, QK_HEAD_DIM],
+            strides=[stride_kn, 1],
+            block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+        )
+
+        desc_v = tl.make_tensor_descriptor(
+            base=V,
+            shape=[KV_LEN, V_HEAD_DIM],
+            strides=[stride_vn, 1],
+            block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+        )
+        {%- endif %}
+
         offs_n = tl.arange(0, BLOCK_N) + off_n

         acc, l_i, m_i = forward_inner(
             {{gen_argdefs()}},
-            q, K, V, None, None, Q_LEN, KV_LEN,
+            q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
             # accumulatd values
             acc, l_i, m_i,
             #offsets
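Both template hunks build the same pair of device-side descriptors, `desc_k` and `desc_v`, and pass them to `forward_inner` in place of the previous `None, None` placeholders, so the inner loop can issue descriptor-based (TMA-style) block loads of K and V rather than computing pointer grids by hand. Outside the Jinja machinery, the Triton API follows the pattern in the sketch below. This is a minimal sketch, not PyTorch code: the `tma_copy*` names are invented for illustration, it assumes a Triton build with device-side `tl.make_tensor_descriptor`, and on NVIDIA backends a host-side scratch allocator must be registered via `triton.set_allocator` before launch:

import torch
import triton
import triton.language as tl


@triton.jit
def tma_copy_kernel(in_ptr, out_ptr, M, N, stride_m,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Same pattern as desc_k / desc_v in the patch: describe the 2D tensor
    # once (base, shape, strides, block shape), then address it by offsets.
    desc_in = tl.make_tensor_descriptor(
        base=in_ptr,
        shape=[M, N],
        strides=[stride_m, 1],  # innermost stride must be 1
        block_shape=[BLOCK_M, BLOCK_N],
    )
    desc_out = tl.make_tensor_descriptor(
        base=out_ptr,
        shape=[M, N],
        strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Descriptor loads take element offsets; boundary handling is done by
    # the descriptor itself, so no explicit tl.load mask is needed.
    tile = desc_in.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    desc_out.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


def tma_copy(x: torch.Tensor) -> torch.Tensor:
    # Host-side allocator for descriptor scratch space (required on CUDA).
    triton.set_allocator(
        lambda size, align, stream: torch.empty(size, dtype=torch.int8, device=x.device)
    )
    out = torch.empty_like(x)
    M, N = x.shape
    BLOCK_M, BLOCK_N = 64, 64
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    tma_copy_kernel[grid](x, out, M, N, x.stride(0), BLOCK_M, BLOCK_N)
    return out

Note the design choice visible in the hunks: `desc_k` and `desc_v` default to `None` outside the `{%- if USE_TMA %}` branch, so the `forward_inner` call site can pass them unconditionally and the inner loop can fall back to ordinary pointer loads when they are absent; only the argument line changes rather than the whole loop body.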