
Commit c82f7dc

Removes sequence length padding logic
Eliminates unnecessary padding of key and value tensors to multiples of 128 in the sequence length dimension. Removes the associated context saving and gradient unpadding operations, which are no longer needed without the sequence length padding. Simplifies the forward and backward pass implementation by removing the conditional padding logic for masks and biases.
1 parent e69b1c7 commit c82f7dc
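
For reference, the removed branch padded the key and value tensors (and any mask or bias) along the sequence dimension up to the next multiple of 128 before calling the kernel, then sliced the padded gradients back in the backward pass. Below is a minimal sketch of that pattern; the helper names are illustrative, not the repository's API, and the (batch_size, seqlen, num_heads, head_size) layout is the one noted in the diff further down.

```python
import torch
import torch.nn.functional as F

def pad_seqlen_to_multiple_of_128(k, v, mask=None, bias=None):
    # Sketch of the branch this commit removes.
    # k, v: (batch_size, seqlen_k, num_heads, head_size);
    # mask/bias have seqlen_k as their last dimension.
    seqlen_k = k.shape[1]
    if seqlen_k % 128 != 0:
        pad = 128 - seqlen_k % 128
        # F.pad pads dims from last to first: [0,0, 0,0, 0,pad] pads dim 1 (seqlen_k).
        k = F.pad(k, [0, 0, 0, 0, 0, pad])
        v = F.pad(v, [0, 0, 0, 0, 0, pad])
        if mask is not None:
            mask = F.pad(mask, [0, pad], value=False)  # padded keys stay masked out
        if bias is not None:
            bias = F.pad(bias, [0, pad], value=0.0)
    return k, v, mask, bias, seqlen_k

def unpad_grads(dk, dv, dbias, seqlen_k):
    # Backward counterpart, also removed: trim gradients back to the
    # original key length that was saved on the autograd context.
    if dk.shape[1] != seqlen_k:
        dk = dk[:, :seqlen_k]
        dv = dv[:, :seqlen_k]
        dbias = dbias[..., :seqlen_k]
    return dk, dv, dbias
```

Since the commit message calls this padding unnecessary, the kernel evidently handles key lengths that are not multiples of 128, and the tensors now reach it at their original sequence length.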

File tree

1 file changed: +0, -16 lines


flash_dmattn/flash_dmattn_interface.py

Lines changed: 0 additions & 16 deletions
@@ -227,8 +227,6 @@ def forward(
         return_softmax: Optional[bool],
         is_grad_enabled: bool = True,
     ):
-        # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size)
-        seqlen_k = k.shape[1]
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, k, v]
         )
@@ -249,14 +247,6 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])

-        if seqlen_k % 128 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=0.0)
-
         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
             k,
@@ -271,7 +261,6 @@ def forward(

         if is_grad:
             ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse)
-            ctx.seqlen_k = seqlen_k
             ctx.softmax_scale = softmax_scale
             ctx.is_causal = is_causal
             ctx.softcap = softcap
@@ -318,11 +307,6 @@ def backward(
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]

-        if ctx.seqlen_k % 128 != 0:
-            dk = dk[:, : ctx.seqlen_k, :, :]
-            dv = dv[:, : ctx.seqlen_k, :, :]
-            dbias = dbias[..., : ctx.seqlen_k]
-
         return dq, dk, dv, None, dbias, None, None, None, None, None, None
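The head-dimension handling in the retained context lines is unchanged: k and v are still padded so head_size is a multiple of 8, and the backward pass still narrows the gradients with dk[..., : dout.shape[-1]]. A minimal sketch of that surviving pattern, with an illustrative helper name and the same assumed layout:

```python
import torch
import torch.nn.functional as F

def pad_head_dim_to_multiple_of_8(k, v):
    # Kept behavior: round the last (head_size) dimension up to a multiple of 8.
    head_size_og = k.shape[-1]
    if head_size_og % 8 != 0:
        pad = 8 - head_size_og % 8
        k = F.pad(k, [0, pad])  # a 2-element pad list touches only the last dim
        v = F.pad(v, [0, pad])
    return k, v, head_size_og

# Example: seqlen 100 (not a multiple of 128) now passes through untouched;
# only head_size 60 is rounded up to 64.
k = torch.randn(2, 100, 8, 60)
v = torch.randn(2, 100, 8, 60)
k, v, head_size_og = pad_head_dim_to_multiple_of_8(k, v)
# After the kernel, outputs and gradients are narrowed back to head_size_og,
# e.g. dk = dk[..., :head_size_og].
```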
0 commit comments