@@ -6,6 +6,7 @@ template <
     typename Element,
     typename ActiveMask,
     bool kIsVarlen,
+    bool kIsDeterministic,
     class... KernelOptions>
 std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
     const at::Tensor& dO,
@@ -36,7 +37,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
   using TileShape = Shape<_128, _128, _128>;

   using Operation = cutlass::fmha::device::
-      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask>;
+      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask, kIsDeterministic>;

   using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
   using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
@@ -219,6 +220,19 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
           hw_info.device_id);

+  auto seqlen_q = kIsVarlen ? max_seq_len_q.value() : q.size(1);
+
+  int* dq_semaphore_ptr = nullptr;
+  at::Tensor dq_semaphore;
+  if (kIsDeterministic) {
+    auto kBlockM = cute::get<0>(TileShape{});
+    auto opts = q.options();
+    dq_semaphore = torch::zeros(
+        {(seqlen_q + kBlockM - 1) / kBlockM, B, H_Q},
+        opts.dtype(torch::kInt32));
+    dq_semaphore_ptr = static_cast<int*>(dq_semaphore.data_ptr());
+  }
+
   typename Operation::Arguments arguments{
       problem_shape,
       static_cast<Element*>(q.data_ptr()),
@@ -240,6 +254,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       static_cast<Element*>(dV.data_ptr()),
       stride_dV,
       softmax_scale,
+      dq_semaphore_ptr,
       window_size_left,
       window_size_right,
       hw_info};
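The dq_semaphore tensor introduced above holds one zero-initialized int32 counter per (Q tile, batch, head). As a minimal sketch of the mechanism, under the assumption that deterministic mode serializes the order in which K/V tiles add their partial dQ (the function below is hypothetical, not the actual Sm100FmhaBwd kernel code):

// Hypothetical sketch: serialize dQ accumulation for one Q tile so the
// floating-point additions always happen in the same K/V-tile order,
// which makes the backward pass bitwise reproducible across runs.
__device__ void accumulate_dq_deterministic(
    int* semaphore,        // counter for this (q_tile, batch, head)
    int kv_tile_idx,       // this CTA's position in the K/V tile order
    float* dq_tile,        // global dQ accumulator for the Q tile
    const float* partial,  // this CTA's partial dQ contribution
    int n) {
  // Wait until every K/V tile ordered before us has finished its addition.
  while (atomicAdd(semaphore, 0) != kv_tile_idx) {
  }
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    dq_tile[i] += partial[i];
  }
  __syncthreads();  // all threads have finished adding
  __threadfence();  // make the updates globally visible
  if (threadIdx.x == 0) {
    atomicAdd(semaphore, 1);  // release the next K/V tile in the order
  }
}

Without the semaphore, tiles accumulate in whatever order the scheduler runs them, and since floating-point addition is not associative, dQ can differ bit-for-bit between runs.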
@@ -264,7 +279,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
     bool causal,
     int64_t window_size_left,
     int64_t window_size_right,
-    bool bottom_right
+    bool bottom_right,
+    bool deterministic
 ) {
   // This workaround initializes the CUDA context to prevent the 201 error
   // (invalid context). When this function is invoked through PyTorch
@@ -294,11 +310,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
   }

   auto dispatch_fmha =
-      [&](auto element, auto element_out, auto varlen, auto mask, auto... kernel_options) {
+      [&](
+          auto element,
+          auto element_out,
+          auto varlen,
+          auto deterministic,
+          auto mask,
+          auto... kernel_options) {
         return fmha_bwd<
             decltype(element),
             decltype(mask),
             varlen,
+            deterministic,
             decltype(kernel_options)...>
             (
                 dOutput,
@@ -315,53 +338,69 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
                 window_size_right);
       };

-  auto dispatch_type = [&](auto varlen, auto mask) {
+  auto dispatch_type = [&](auto varlen, auto deterministic, auto mask) {
     if (query.dtype() == torch::kFloat16) {
-      return dispatch_fmha(cutlass::half_t{}, cutlass::half_t{}, varlen, mask);
+      return dispatch_fmha(
+          cutlass::half_t{}, cutlass::half_t{}, varlen, deterministic, mask);
     }
     else if (query.dtype() == torch::kBFloat16) {
       return dispatch_fmha(
-          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, mask);
+          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
     }
     else if (query.dtype() == torch::kFloat8_e4m3fn) {
       return dispatch_fmha(
-          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, mask);
+          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
     }
     TORCH_CHECK(false, "Unsupported dtype for q: ", query.dtype());
   };

-  auto dispatch_mask = [&](auto varlen) {
+  auto dispatch_mask = [&](auto varlen, auto deterministic) {
     if (causal) {
       if (bottom_right) {
-        return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/false>{});
+        return dispatch_type(
+            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/false>{});
       }
       else {
-        return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/true>{});
+        return dispatch_type(
+            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/true>{});
       }
     }
     else if (local) {
       if (bottom_right) {
-        return dispatch_type(varlen, LocalMaskForBackward</*kIsQBegin=*/false>{});
+        return dispatch_type(
+            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/false>{});
       }
       else {
-        return dispatch_type(varlen, LocalMaskForBackward</*kIsQBegin=*/true>{});
+        return dispatch_type(
+            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/true>{});
       }
     }
     else if (varlen || key.size(1) % 128 != 0) {
       // Use the residual mask for varlen or when K seqlen is not multiple of
       // blockN
-      return dispatch_type(varlen, ResidualMaskForBackward{});
+      return dispatch_type(
+          varlen, deterministic, ResidualMaskForBackward{});
+    }
+    else {
+      return dispatch_type(
+          varlen, deterministic, NoMask{});
+    }
+  };
+
+  auto dispatch_deterministic = [&](auto varlen) {
+    if (deterministic) {
+      return dispatch_mask(varlen, std::bool_constant<true>{});
     }
     else {
-      return dispatch_type(varlen, NoMask{});
+      return dispatch_mask(varlen, std::bool_constant<false>{});
     }
   };

   if (max_seq_len_q.has_value()) {
-    return dispatch_mask(std::bool_constant<true>{});
+    return dispatch_deterministic(std::bool_constant<true>{});
   } else {
     TORCH_CHECK(query.dim() == 4, "q must be [B, M, H, D] for fixed length")
-    return dispatch_mask(std::bool_constant<false>{});
+    return dispatch_deterministic(std::bool_constant<false>{});
   }
 }
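dispatch_deterministic mirrors the existing varlen dispatch: the runtime bool is lifted into a std::bool_constant so it can be forwarded through the lambdas and ultimately used as the kIsDeterministic non-type template parameter of fmha_bwd. A self-contained sketch of the pattern, with illustrative names:

#include <cstdio>
#include <type_traits>

// run<kFlag> stands in for fmha_bwd<..., kIsDeterministic, ...>.
template <bool kFlag>
void run() {
  std::printf("compiled with kFlag = %d\n", kFlag);
}

void dispatch(bool flag) {
  // f is a std::bool_constant<B>; its ::value is a compile-time constant.
  auto go = [&](auto f) { return run<decltype(f)::value>(); };
  if (flag) {
    go(std::bool_constant<true>{});   // instantiates run<true>
  } else {
    go(std::bool_constant<false>{});  // instantiates run<false>
  }
}

int main() {
  dispatch(true);   // prints: compiled with kFlag = 1
  dispatch(false);  // prints: compiled with kFlag = 0
}

Each branch instantiates a separate template specialization, so the deterministic and default paths are compiled independently and the flag costs nothing inside the kernel.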
@@ -383,7 +422,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "bool causal=False, "
       "int window_size_left=-1, "
       "int window_size_right=-1, "
-      "bool bottom_right=True"
+      "bool bottom_right=True, "
+      "bool deterministic=False"
       ") -> (Tensor, Tensor, Tensor)"
   );
 }
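Since the schema defaults deterministic to False, existing callers are unaffected; the deterministic path, including the semaphore allocation above, is strictly opt-in.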