
Commit ef0483d

Revert "Ensure large tensor int32 -> int64 indexing is enabled (pytorch#157767)"
This reverts commit b36a20d. Reverted pytorch#157767 on behalf of https://github.com/atalman due to internal test failures ([comment](pytorch#157767 (comment))).
1 parent 5432966 · commit ef0483d

6 files changed: +120 -100 lines changed

test/inductor/test_flex_attention.py
Lines changed: 0 additions & 36 deletions

@@ -48,7 +48,6 @@
     skipCPUIf,
     skipCUDAIf,
 )
-from torch.testing._internal.common_utils import IS_FBCODE
 from torch.utils._triton import has_triton, has_triton_tma_device


@@ -4340,41 +4339,6 @@ def simple_score_mod(score, b, h, q_idx, kv_idx):
         fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
         fa._WARNINGS_SHOWN = original_warnings_shown

-    @largeTensorTest("38GB", "cuda") # emperically
-    @skip_on_cpu
-    @unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
-    def test_int64_indexing_large_stride(self, device):
-        B = 1
-        H = 64
-        S = 2**20
-        D = 64
-        dtype = torch.float16
-
-        def _simple_causal(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        BLOCK_M = 1024
-        BLOCK_N = 1024
-
-        block_mask = torch.compile(create_block_mask)(
-            _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
-        )
-
-        q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-        k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-        v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-
-        # Test forward and backward pass
-        out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
-        loss = out.sum()
-        loss.backward()
-
-        # Basic correctness checks, doing full comapre consumes too much memory :/
-        self.assertEqual(out.shape, (B, H, S, D))
-        self.assertTrue(q.grad is not None)
-        self.assertTrue(k.grad is not None)
-        self.assertTrue(v.grad is not None)
-

 class TestBlockMask(InductorTestCase):
     def setUp(self):
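The deleted test compiled both create_block_mask and flex_attention at S = 2**20, a size at which q/k/v hold 2**32 elements each and int32 offsets would overflow. For reference, a scaled-down sketch of the same usage pattern is shown below; the shapes are illustrative stand-ins, and the snippet is not part of this commit.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes only; the removed test used B=1, H=64, S=2**20, D=64 (~38GB).
B, H, S, D = 1, 4, 4096, 64

def simple_causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = torch.compile(create_block_mask)(
    simple_causal, B, H, S, S, device="cuda", BLOCK_SIZE=(128, 128)
)

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)

# Forward and backward, mirroring the structure of the removed test
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
out.sum().backward()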

torch/_inductor/kernel/flex/templates/common.py.jinja
Lines changed: 13 additions & 24 deletions

@@ -4,7 +4,7 @@
 @triton.jit
 def forward_block_mn(
     {{gen_argdefs()}},
-    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+    q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
     # accumulated values
     acc, l_i, m_i,
     # Offsets
@@ -13,8 +13,6 @@ def forward_block_mn(
     kv_start,
     kv_offset,
     MATMUL_PRECISION, RCP_LN2,
-    # Strides for K and V
-    stride_kk, stride_kn, stride_vn, stride_vk,
     IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

 ):
@@ -23,21 +21,17 @@ def forward_block_mn(

     # -- load k --
     # NB reversed order to since K is transposed
-    kv_base_offset = kv_start + kv_offset
     {%- if USE_TMA %}
     k = tl.load_tensor_descriptor(
         desc_k,
-        [kv_base_offset, 0],
+        [kv_start + kv_offset, 0],
     )
     {%- else %}
-
-    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
-    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
-    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
-    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
+    k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
     {%- endif %}

-    k = tl.trans(k)
+    if USE_TMA:
+        k = tl.trans(k)
     # -- compute qk ---
     qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
     if not PRESCALE_QK:
@@ -104,12 +98,10 @@ def forward_block_mn(
     {%- if USE_TMA %}
     v = tl.load_tensor_descriptor(
         desc_v,
-        [kv_base_offset, 0],
+        [kv_start + kv_offset, 0],
     )
     {%- else %}
-    # Calculate offsets for V loading - reuse kv_base_offset from K loading
-    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
-    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
+    v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
     {%- endif %}
     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

@@ -121,7 +113,7 @@ def forward_block_mn(
 @triton.jit
 def forward_inner(
     {{gen_argdefs()}},
-    q, K, V,
+    q, K_block_ptr, V_block_ptr,
     desc_k, desc_v, Q_LEN, KV_LEN,
     # accumulated values
     acc, l_i, m_i,
@@ -135,8 +127,6 @@ def forward_inner(
     # start kv and end kv block
     block_n_start, block_n_end,
     MATMUL_PRECISION,
-    # Strides for K and V
-    stride_kk, stride_kn, stride_vn, stride_vk,
     IS_FULL_BLOCKS,
 ):
     # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
@@ -156,7 +146,7 @@ def forward_inner(
     if IS_DIVISIBLE:
         acc, l_i, m_i = forward_block_mn(
             {{gen_argdefs()}},
-            q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+            q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
             # accumulated values
             acc, l_i, m_i,
             # Offsets
@@ -165,8 +155,6 @@ def forward_inner(
             kv_start,
             kv_offset,
             MATMUL_PRECISION, RCP_LN2,
-            # Strides for K and V
-            stride_kk, stride_kn, stride_vn, stride_vk,
             IS_FULL_BLOCKS,
         )
     else:
@@ -176,7 +164,7 @@ def forward_inner(
         # to the last block because it's faster a lot.
         acc, l_i, m_i = forward_block_mn(
             {{gen_argdefs()}},
-            q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+            q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
             # accumulated values
             acc, l_i, m_i,
             # Offsets
@@ -185,8 +173,6 @@ def forward_inner(
             kv_start,
             kv_offset,
             MATMUL_PRECISION, RCP_LN2,
-            # Strides for K and V
-            stride_kk, stride_kn, stride_vn, stride_vk,
             IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
         )

@@ -199,6 +185,9 @@ def forward_inner(

         offs_n = offs_n + offset
         kv_offset += offset
+        if not USE_TMA:
+            K_block_ptr = tl.advance(K_block_ptr, (0, offset))
+            V_block_ptr = tl.advance(V_block_ptr, (offset, 0))


     return acc, l_i, m_i
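The non-TMA path restored above loads K and V through Triton block pointers (built with tl.make_block_ptr in flex_attention.py.jinja below) and walks them with tl.advance, instead of computing raw stride-based offsets with load_checked_2d. A minimal, self-contained sketch of that block-pointer pattern follows; the kernel, names, and shapes are hypothetical and not part of this commit, and the boundary-checked tl.load stands in for the template's load_checked_block helper.

import torch
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(X, Out, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    # Block pointer over a [BLOCK_M, BLOCK_N] tile of X, analogous to K_block_ptr/V_block_ptr
    x_block_ptr = tl.make_block_ptr(
        base=X,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for _ in range(0, tl.cdiv(N, BLOCK_N)):
        # Boundary-checked load with zero padding, in the spirit of load_checked_block
        x = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.sum(x, axis=1)
        # Advance the block pointer to the next column block, as the revert does for K/V
        x_block_ptr = tl.advance(x_block_ptr, (0, BLOCK_N))
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(Out + offs_m, acc, mask=offs_m < M)

x = torch.randn(1000, 777, device="cuda")
out = torch.empty(1000, device="cuda")
grid = (triton.cdiv(1000, 64),)
row_sum_kernel[grid](x, out, 1000, 777, x.stride(0), x.stride(1), BLOCK_M=64, BLOCK_N=128)
assert torch.allclose(out, x.sum(dim=1), atol=1e-3)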

torch/_inductor/kernel/flex/templates/flex_attention.py.jinja
Lines changed: 58 additions & 16 deletions

@@ -45,9 +45,9 @@

     MATMUL_PRECISION = Q.dtype.element_ty

-    q_start = tl.program_id(0).to(INDEX_DTYPE)
-    off_zq = tl.program_id(1).to(INDEX_DTYPE)
-    off_hq = tl.program_id(2).to(INDEX_DTYPE)
+    q_start = tl.program_id(0)
+    off_zq = tl.program_id(1)
+    off_hq = tl.program_id(2)

     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
     # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
@@ -114,16 +114,27 @@
     sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
     sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
     sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
+    K_block_ptr = None
+    V_block_ptr = None
+    Q_block_ptr = None
+
+    if not USE_TMA:
+        Q_block_ptr = tl.make_block_ptr(
+            base=Q,
+            shape=(Q_LEN, QK_HEAD_DIM),
+            strides=(stride_qm, stride_qk),
+            offsets=(q_start * BLOCK_M, 0),
+            block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
+            order=(1, 0)
+        )

     {%- if USE_TMA %}
     q = tl.load_tensor_descriptor(
         desc_q,
         [(q_start * BLOCK_M).to(tl.int32), 0],
     )
     {%- else %}
-    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
-    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
+    q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
     {%- endif %}

     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -135,22 +146,38 @@
     block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))


-    # K and V pointers will be passed directly to forward_inner
+    if not USE_TMA:
+        K_block_ptr = tl.make_block_ptr(
+            base=K,
+            shape=(QK_HEAD_DIM, KV_LEN),
+            strides=(stride_kk, stride_kn),
+            offsets=(0, kv_start),
+            block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
+            order=(0, 1)
+        )
+
+        V_block_ptr = tl.make_block_ptr(
+            base=V,
+            shape=(KV_LEN, V_HEAD_DIM),
+            strides=(stride_vn, stride_vk),
+            offsets=(kv_start, 0),
+            block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
+            order=(1, 0)
+        )

     offs_n = kv_start + tl.arange(0, BLOCK_N)


     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V,
+        q, K_block_ptr, V_block_ptr,
         desc_k, desc_v, Q_LEN, KV_LEN,
         acc, l_i, m_i,
         off_zq, off_hq, offs_m[:, None], offs_n[None, :],
         kv_start,
         kv_indices, kv_num_blocks,
         0, block_n_end,
         MATMUL_PRECISION,
-        stride_kk, stride_kn, stride_vn, stride_vk,
         IS_FULL_BLOCKS=False,
     )

@@ -163,20 +190,35 @@
     kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
     kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
     block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
-    # K and V pointers will be passed directly to forward_inner
+    if not USE_TMA:
+        K_block_ptr = tl.make_block_ptr(
+            base=K,
+            shape=(QK_HEAD_DIM, KV_LEN),
+            strides=(stride_kk, stride_kn),
+            offsets=(0, kv_start),
+            block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
+            order=(0, 1)
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V,
+            shape=(KV_LEN, V_HEAD_DIM),
+            strides=(stride_vn, stride_vk),
+            offsets=(kv_start, 0),
+            block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
+            order=(1, 0)
+        )
     offs_n = kv_start + tl.arange(0, BLOCK_N)

     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V,
+        q, K_block_ptr, V_block_ptr,
         desc_k, desc_v, Q_LEN, KV_LEN,
         acc, l_i, m_i,
         off_zq, off_hq, offs_m[:, None], offs_n[None, :],
         kv_start,
         kv_indices, kv_num_blocks,
         0, block_n_end,
         MATMUL_PRECISION,
-        stride_kk, stride_kn, stride_vn, stride_vk,
         IS_FULL_BLOCKS=True,
     )

@@ -187,10 +229,10 @@
     l_i = tl.where(l_i == 0.0, 1, l_i)

     acc = acc / l_i[:, None]
-    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
-    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
-    idx_m = offs_m[:, None].to(INDEX_DTYPE)
-    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
+    idx_zq = tl.program_id(1)
+    idx_hq = tl.program_id(2)
+    idx_m = offs_m[:, None]
+    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]

     mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
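The deletions above (and the matching ones in flex_backwards.py.jinja below) drop the .to(INDEX_DTYPE) casts that pytorch#157767 added so program ids and offsets participate in int64 arithmetic. A quick back-of-the-envelope check, reusing the shapes from the deleted test, shows why int32 offsets are too small at that size; this snippet is an illustrative sketch, not part of the commit.

import torch

B, H, S, D = 1, 64, 2**20, 64          # shapes used by the deleted test
max_linear_offset = B * H * S * D - 1  # largest element offset into q/k/v: 2**32 - 1
print(max_linear_offset > torch.iinfo(torch.int32).max)  # True -> int32 indexing overflows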

torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja
Lines changed: 3 additions & 3 deletions

@@ -51,12 +51,12 @@

     MATMUL_PRECISION = Q.dtype.element_ty

-    pid = tl.program_id(0).to(INDEX_DTYPE)
+    pid = tl.program_id(0)
     NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
     NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

-    off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
-    off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
+    off_zq = tl.program_id(1) # q batch idx
+    off_hkv = tl.program_id(2) # kv head idx
     off_zkv = off_zq % ZKV # kv batch idx

     SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
