Commit 1a93ffe

sryap authored and facebook-github-bot committed
Enable deterministic mode in Cutlass attn (#4840)
Summary:
Pull Request resolved: #4840
X-link: facebookresearch/FBGEMM#1865

Enable deterministic mode in Cutlass Blackwell Attention. Previously, all parts of the kernels were deterministic except for the dQ computation in the backward pass, which performed its reduction across thread blocks with atomic adds. This diff serializes the dQ aggregation across thread blocks, ensuring full determinism throughout kernel execution.

Changes include:
- Rebased D79426672 onto master
- Enabled deterministic mode for sliding-window attention
- Enabled deterministic mode for GQA in the backward pass
- Added unit tests

Reviewed By: y-sq, jianyuh

Differential Revision: D81980925

fbshipit-source-id: 1fac8921d166e56692111c8a844a559c6036c7e2
1 parent e0b24f6 commit 1a93ffe
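For context on what the fix addresses: floating-point addition is not associative, so when thread blocks race to combine their partial dQ tiles with atomic adds, the accumulation order (and hence the low-order bits of dQ) can change from run to run. Serializing the aggregation fixes the order, so every run produces bitwise-identical results. A minimal PyTorch illustration of the underlying effect (not code from this commit):

import torch

# Summing the same values in two different orders: float addition is not
# associative, so the two results typically differ in the last few bits.
# This order sensitivity is what made the atomicAdd-based dQ reduction
# non-deterministic across runs.
x = torch.randn(1 << 20, dtype=torch.float32)
a = x.sum()
b = x[torch.randperm(x.numel())].sum()
print(torch.equal(a, b))   # frequently False
print((a - b).item())      # usually a few ulps of drift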

File tree

6 files changed, +244 -63 lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 10 additions & 1 deletion
@@ -104,7 +104,9 @@ def _cutlass_blackwell_fmha_backward(
     window_left: int = -1,
     window_right: int = -1,
     bottom_right: bool = True,
+    deterministic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    deterministic = deterministic or torch.are_deterministic_algorithms_enabled()
     dout = maybe_contiguous(dout)
     q = maybe_contiguous(q)
     k = maybe_contiguous(k)
@@ -125,6 +127,7 @@ def _cutlass_blackwell_fmha_backward(
         window_size_left=window_left,
         window_size_right=window_right,
         bottom_right=bottom_right,
+        deterministic=deterministic,
     )


@@ -172,6 +175,7 @@ def forward( # type: ignore
         seqlen_kv: Optional[torch.Tensor] = None,
         window_size: Tuple[int, int] = (-1, -1),
         bottom_right: bool = True,
+        deterministic: bool = False,
     ) -> torch.Tensor:
         # Check if this is generation phase (sq = 1)
         sq = q.shape[1]
@@ -232,6 +236,7 @@ def forward( # type: ignore
         ctx.cu_seqlens_k = cu_seqlens_k
         ctx.is_gen = False
         ctx.bottom_right = bottom_right
+        ctx.deterministic = deterministic
         return out

     @staticmethod
@@ -248,6 +253,7 @@ def backward(ctx, dout: torch.Tensor, *args: Any) -> Tuple[ # type: ignore
         None,
         None,
         None,
+        None,
     ]:
         if ctx.is_gen:
             # For gen case, no backward pass is needed (generation is inference only)
@@ -272,8 +278,9 @@ def backward(ctx, dout: torch.Tensor, *args: Any) -> Tuple[ # type: ignore
             window_left,
             window_right,
             bottom_right=ctx.bottom_right,
+            deterministic=ctx.deterministic,
         )
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None


 def cutlass_blackwell_fmha_func(
@@ -289,6 +296,7 @@ def cutlass_blackwell_fmha_func(
     seqlen_kv: torch.Tensor | None = None,
     window_size: tuple[int, int] | None = (-1, -1),
     bottom_right: bool = True,
+    deterministic: bool = False,
 ):
     return CutlassBlackwellFmhaFunc.apply(
         q,
@@ -303,4 +311,5 @@ def cutlass_blackwell_fmha_func(
         seqlen_kv,
         window_size,
         bottom_right,
+        deterministic,
     )
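The Python changes above thread a `deterministic` flag through the autograd function and OR it with PyTorch's global deterministic-algorithms switch, so determinism can be requested per call or process-wide. A hedged usage sketch (the import path, tensor shapes, and reliance on the remaining defaults are illustrative assumptions, not taken from this diff):

import torch
# Illustrative import; the actual module layout may differ.
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
    cutlass_blackwell_fmha_func,
)

q = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Option 1: request determinism for this call only.
out = cutlass_blackwell_fmha_func(q, k, v, deterministic=True)

# Option 2: flip the global PyTorch switch; the backward picks it up via
# torch.are_deterministic_algorithms_enabled() even when the flag is left False.
torch.use_deterministic_algorithms(True)
out = cutlass_blackwell_fmha_func(q, k, v)
out.sum().backward()  # dQ aggregation now runs in the serialized, deterministic mode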

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu

Lines changed: 57 additions & 17 deletions
@@ -6,6 +6,7 @@ template <
     typename Element,
     typename ActiveMask,
     bool kIsVarlen,
+    bool kIsDeterministic,
     class... KernelOptions>
 std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
     const at::Tensor& dO,
@@ -36,7 +37,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
   using TileShape = Shape<_128, _128, _128>;

   using Operation = cutlass::fmha::device::
-      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask>;
+      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask, kIsDeterministic>;

   using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
   using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
@@ -219,6 +220,19 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
           hw_info.device_id);

+  auto seqlen_q = kIsVarlen ? max_seq_len_q.value() : q.size(1);
+
+  int* dq_semaphore_ptr = nullptr;
+  at::Tensor dq_semaphore;
+  if (kIsDeterministic) {
+    auto kBlockM = cute::get<0>(TileShape{});
+    auto opts = q.options();
+    dq_semaphore = torch::zeros(
+        {(seqlen_q + kBlockM - 1) / kBlockM, B, H_Q},
+        opts.dtype(torch::kInt32));
+    dq_semaphore_ptr = static_cast<int*>(dq_semaphore.data_ptr());
+  }
+
   typename Operation::Arguments arguments{
       problem_shape,
       static_cast<Element*>(q.data_ptr()),
@@ -240,6 +254,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       static_cast<Element*>(dV.data_ptr()),
       stride_dV,
       softmax_scale,
+      dq_semaphore_ptr,
       window_size_left,
       window_size_right,
       hw_info};
@@ -264,7 +279,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
     bool causal,
     int64_t window_size_left,
     int64_t window_size_right,
-    bool bottom_right
+    bool bottom_right,
+    bool deterministic
 ) {
   // This workaround initializes the CUDA context to prevent the 201 error
   // (invalid context). When this function is invoked through PyTorch
@@ -294,11 +310,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
   }

   auto dispatch_fmha =
-      [&](auto element, auto element_out, auto varlen, auto mask, auto... kernel_options) {
+      [&](
+          auto element,
+          auto element_out,
+          auto varlen,
+          auto deterministic,
+          auto mask,
+          auto... kernel_options) {
         return fmha_bwd<
             decltype(element),
             decltype(mask),
             varlen,
+            deterministic,
             decltype(kernel_options)...>
             (
                 dOutput,
@@ -315,53 +338,69 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
                 window_size_right);
       };

-  auto dispatch_type = [&](auto varlen, auto mask) {
+  auto dispatch_type = [&](auto varlen, auto deterministic, auto mask) {
     if (query.dtype() == torch::kFloat16) {
-      return dispatch_fmha(cutlass::half_t{}, cutlass::half_t{}, varlen, mask);
+      return dispatch_fmha(
+          cutlass::half_t{}, cutlass::half_t{}, varlen, deterministic, mask);
     }
     else if (query.dtype() == torch::kBFloat16) {
       return dispatch_fmha(
-          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, mask);
+          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
     }
     else if (query.dtype() == torch::kFloat8_e4m3fn) {
       return dispatch_fmha(
-          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, mask);
+          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
     }
     TORCH_CHECK(false, "Unsupported dtype for q: ", query.dtype());
   };

-  auto dispatch_mask = [&](auto varlen) {
+  auto dispatch_mask = [&](auto varlen, auto deterministic) {
     if (causal) {
       if (bottom_right) {
-        return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/false>{});
+        return dispatch_type(
+            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/false>{});
       }
       else {
-        return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/true>{});
+        return dispatch_type(
+            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/true>{});
       }
     }
     else if (local) {
       if (bottom_right) {
-        return dispatch_type(varlen, LocalMaskForBackward</*kIsQBegin=*/false>{});
+        return dispatch_type(
+            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/false>{});
       }
       else {
-        return dispatch_type(varlen, LocalMaskForBackward</*kIsQBegin=*/true>{});
+        return dispatch_type(
+            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/true>{});
       }
     }
     else if (varlen || key.size(1) % 128 != 0) {
       // Use the residual mask for varlen or when K seqlen is not multiple of
       // blockN
-      return dispatch_type(varlen, ResidualMaskForBackward{});
+      return dispatch_type(
+          varlen, deterministic, ResidualMaskForBackward{});
+    }
+    else {
+      return dispatch_type(
+          varlen, deterministic, NoMask{});
+    }
+  };
+
+  auto dispatch_deterministic = [&](auto varlen) {
+    if (deterministic) {
+      return dispatch_mask(varlen, std::bool_constant<true>{});
     }
     else {
-      return dispatch_type(varlen, NoMask{});
+      return dispatch_mask(varlen, std::bool_constant<false>{});
     }
   };

   if (max_seq_len_q.has_value()) {
-    return dispatch_mask(std::bool_constant<true>{});
+    return dispatch_deterministic(std::bool_constant<true>{});
   } else {
     TORCH_CHECK(query.dim() == 4, "q must be [B, M, H, D] for fixed length")
-    return dispatch_mask(std::bool_constant<false>{});
+    return dispatch_deterministic(std::bool_constant<false>{});
   }
 }

@@ -383,7 +422,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " bool causal=False, "
       " int window_size_left=-1, "
       " int window_size_right=-1, "
-      " bool bottom_right=True"
+      " bool bottom_right=True, "
+      " bool deterministic=False"
       ") -> (Tensor, Tensor, Tensor)"
   );
 }
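Two details worth noting in the dispatch above: `dq_semaphore` holds one zero-initialized int32 per (q-tile, batch, head) triple, and `dispatch_deterministic` lifts the runtime `deterministic` bool into a compile-time `std::bool_constant`, so the serialized path costs nothing when determinism is off. A Python mirror of the semaphore shape arithmetic (illustrative only, not code from this diff):

import torch

def dq_semaphore_shape(seqlen_q: int, B: int, H_Q: int, kBlockM: int = 128):
    # Mirrors the C++ above: ceil-divide the query length into kBlockM-row
    # tiles, with one semaphore per (tile, batch, head).
    num_q_tiles = (seqlen_q + kBlockM - 1) // kBlockM
    return (num_q_tiles, B, H_Q)

# e.g. seqlen_q=1000, B=4, H_Q=16 -> (8, 4, 16), i.e. 512 int32 semaphores
dq_semaphore = torch.zeros(dq_semaphore_shape(1000, 4, 16), dtype=torch.int32)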

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp

Lines changed: 9 additions & 3 deletions
@@ -58,7 +58,8 @@ template<
   class ElementAccumulator,
   class TileShape,
   bool IsMla,
-  class Mask
+  class Mask,
+  bool IsDeterministic=false
 >
 class Sm100FmhaBwd {
  private:
@@ -123,6 +124,8 @@ class Sm100FmhaBwd {

     ElementAccumulator softmax_scale;

+    int* ptr_dq_semaphore;
+
     int window_size_left = -1;
     int window_size_right = -1;

@@ -138,7 +141,7 @@ class Sm100FmhaBwd {

   using OperationMha = cutlass::fmha::device::FMHA<
     cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
-      ProblemShape, Element, ElementAccumulator, TileShape, Mask
+      ProblemShape, Element, ElementAccumulator, TileShape, Mask, IsDeterministic
     >
   >;

@@ -223,7 +226,10 @@ class Sm100FmhaBwd {
         scaled_lse, to_bwd_stride(stride_scaled_lse),
         sum_OdO, to_bwd_stride(stride_sum_OdO),
         dQ_acc, to_bwd_stride(stride_dQ),
-        args.softmax_scale, args.window_size_left, args.window_size_right},
+        args.softmax_scale,
+        args.ptr_dq_semaphore,
+        args.window_size_left,
+        args.window_size_right },
       { args.ptr_dK, to_bwd_stride(args.stride_dK),
         args.ptr_dV, to_bwd_stride(args.stride_dV) },
       args.hw_info
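The summary mentions added unit tests; a minimal check in the same spirit (hypothetical shapes and call, reusing the illustrative import from the usage sketch above) runs the backward twice and asserts that dQ matches bit for bit:

import torch

def dq_of_one_run() -> torch.Tensor:
    torch.manual_seed(0)
    q = torch.randn(2, 256, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    out = cutlass_blackwell_fmha_func(q, k, v, deterministic=True)
    out.sum().backward()
    return q.grad.detach().clone()

# With deterministic=True this must hold on every pair of runs; with the old
# atomicAdd reduction, the low-order bits of dQ could differ between runs.
assert torch.equal(dq_of_one_run(), dq_of_one_run())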
