Skip to content

Commit e23b08f

Browse files
committed
Adds head size padding and comments sequence length padding
Improves memory alignment by ensuring head dimensions are padded to multiples of 8 for 16-bit memory allocations. Comments out sequence length padding implementation for future consideration, including corresponding mask and bias padding logic in both forward and backward passes.
1 parent 8a3bb04 commit e23b08f

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

flash_dmattn/flash_dmattn_interface.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,20 @@ def forward(
241241
if return_softmax is None:
242242
return_softmax = False
243243

244+
# Padding to multiple of 8 for 16-bit memory allocations
244245
head_size_og = q.size(3)
245246
if head_size_og % 8 != 0:
246247
q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
247248
k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
248249
v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
250+
# seqlen_k_og = k.shape[1]
251+
# if seqlen_k_og % 8 != 0:
252+
# k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
253+
# v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
254+
# if mask is not None:
255+
# mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
256+
# if bias is not None:
257+
# bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)
249258

250259
out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
251260
q,
@@ -265,6 +274,7 @@ def forward(
265274
ctx.is_causal = is_causal
266275
ctx.softcap = softcap
267276
ctx.deterministic = deterministic
277+
# ctx.seqlen_k_og = seqlen_k_og
268278

269279
out = out_padded[..., :head_size_og]
270280

@@ -307,6 +317,11 @@ def backward(
307317
dk = dk[..., : dout.shape[-1]]
308318
dv = dv[..., : dout.shape[-1]]
309319

320+
# if ctx.seqlen_k_og % 8 != 0:
321+
# dk = dk[:, : ctx.seqlen_k_og, :, :]
322+
# dv = dv[:, : ctx.seqlen_k_og, :, :]
323+
# dbias = dbias[..., : ctx.seqlen_k_og]
324+
310325
return dq, dk, dv, None, dbias, None, None, None, None, None, None
311326

312327

0 commit comments

Comments
 (0)