
Commit de5dc1f

eqy authored and pytorchmergebot committed
[cuDNN][SDPA][Nested Tensor] add forward/backward caching support for cuDNN SDPA Nested tensor/varlen (pytorch#161434)
Don't recompile every time

Pull Request resolved: pytorch#161434
Approved by: https://github.com/drisspg
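The core of the change: the nested/varlen paths now go through the same parameter-keyed graph cache as the dense path, so a graph is compiled once per unique (shapes, dtypes, flags, is_nested) key and reused afterwards. Below is a minimal sketch of that build-once pattern; GraphKey, Graph, and get_or_build are hypothetical stand-ins for the real MHACacheKeyWrapper / getMHAGraphCache_() machinery, not the PyTorch API.

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

// Hypothetical stand-ins for the real types (MHAParams, cudnn_frontend graph).
struct GraphKey { std::size_t hash; };  // hashes shapes, dtypes, is_nested, ...
struct Graph {};                        // a compiled attention graph

// Look the key up first; only compile on a miss. Repeated calls with the
// same shapes/flags then skip recompilation entirely.
std::shared_ptr<Graph> get_or_build(
    std::unordered_map<std::size_t, std::shared_ptr<Graph>>& cache,
    const GraphKey& key,
    const std::function<std::shared_ptr<Graph>()>& build) {
  auto it = cache.find(key.hash);
  if (it != cache.end()) {
    return it->second;  // cache hit: reuse the compiled graph
  }
  auto graph = build(); // cache miss: compile once, then remember it
  cache.emplace(key.hash, graph);
  return graph;
}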
1 parent 72e6717 commit de5dc1f

2 files changed (+99, -48 lines)

aten/src/ATen/native/cudnn/MHA.cpp

Lines changed: 98 additions & 47 deletions
@@ -146,7 +146,7 @@ namespace native {
 
 namespace fe = cudnn_frontend;
 
-#define MAX_MHA_DIM 4
+constexpr uint8_t MAX_MHA_DIM = 4;
 
 // Whether we will use ragged offsets in the dense (non-nested) path
 // to avoid recompilation
@@ -238,7 +238,8 @@ void setMHAParams(
     const std::optional<Tensor>& attn_bias,
     double dropout_probability,
     bool is_causal,
-    bool return_softmaxstats) {
+    bool return_softmaxstats,
+    bool is_nested) {
   memset(&params, 0, sizeof(MHAParams));
   params.device_id = at::cuda::current_device();
   params.dataType = fe::DataType_t::HALF;
@@ -255,23 +256,24 @@ void setMHAParams(
   params.is_causal = is_causal;
   params.return_softmaxstats = return_softmaxstats;
   params.has_attn_bias = attn_bias.has_value();
+  // Expect 4D dense tensor, 3D nested case (THD)
   TORCH_INTERNAL_ASSERT(
-      q.sizes().size() == MAX_MHA_DIM,
+      q.sizes().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
   TORCH_INTERNAL_ASSERT(
-      q.strides().size() == MAX_MHA_DIM,
+      q.strides().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
   TORCH_INTERNAL_ASSERT(
-      k.sizes().size() == MAX_MHA_DIM,
+      k.sizes().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "K tensor has unexpected number of dims, please report a bug to PyTorch.");
   TORCH_INTERNAL_ASSERT(
-      k.strides().size() == MAX_MHA_DIM,
+      k.strides().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "K tensor has unexpected number of dims, please report a bug to PyTorch.");
   TORCH_INTERNAL_ASSERT(
-      v.sizes().size() == MAX_MHA_DIM,
+      v.sizes().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "V tensor has unexpected number of dims, please report a bug to PyTorch.");
   TORCH_INTERNAL_ASSERT(
-      v.strides().size() == MAX_MHA_DIM,
+      v.strides().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested),
       "V tensor has unexpected number of dims, please report a bug to PyTorch.");
   std::copy(q.sizes().begin(), q.sizes().end(), params.q_dim.begin());
   std::copy(q.strides().begin(), q.strides().end(), params.q_stride.begin());
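For context on the rank check above: dense SDPA inputs are 4-D [B, H, S, D], while packed nested/varlen (THD) inputs are 3-D [total_tokens, H, D], which is why the expected rank is MAX_MHA_DIM minus is_nested. A small self-contained illustration, with shapes invented for the example:

#include <cassert>
#include <cstdint>
#include <vector>

// Dense input is 4-D, packed nested (THD) input is 3-D; the same check
// covers both by subtracting is_nested from the dense rank.
constexpr uint8_t MAX_MHA_DIM = 4;

int main() {
  std::vector<int64_t> dense_q  = {8, 16, 128, 64};  // [B, H, S, D]
  std::vector<int64_t> nested_q = {1024, 16, 64};    // [total_tokens, H, D]
  bool is_nested = false;
  assert(dense_q.size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested));
  is_nested = true;
  assert(nested_q.size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested));
  return 0;
}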
@@ -320,7 +322,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
       const std::optional<Tensor>& attn_bias,
       double dropout_probability,
       bool is_causal,
-      bool return_softmaxstats) {
+      bool return_softmaxstats,
+      bool is_nested) {
     setMHAParams(
         this->pod,
         b,
@@ -335,7 +338,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
         attn_bias,
         dropout_probability,
         is_causal,
-        return_softmaxstats);
+        return_softmaxstats,
+        is_nested);
   }
 };
 
@@ -1386,7 +1390,8 @@ void run_cudnn_SDP_fprop(
       attn_bias,
       dropout_probability,
       is_causal,
-      return_softmaxstats);
+      return_softmaxstats,
+      false);
   auto graph_ptr = getMHAGraphCache_().find(key);
   std::shared_ptr<fe::graph::Graph> mha_graph;
   if (graph_ptr) {
@@ -1484,30 +1489,53 @@ void run_cudnn_SDP_fprop_nestedtensor(
   if (return_softmaxstats && !softmaxstats.defined()) {
     softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
   }
-  auto mha_graph = build_graph_nestedtensor(
+
+  auto key = MHACacheKeyWrapper(
       b,
       h_q,
-      h_k,
-      h_v,
-      s_q,
-      s_kv,
+      s_q, // max-seqlen-q
+      s_kv, // max-seqlen-kv
       d_qk,
       d_v,
-      scaling_factor,
-      return_softmaxstats,
-      is_causal,
-      dropout_probability,
-      cum_seqlen_q,
-      cum_seqlen_kv,
       q,
       k,
       v,
       attn_bias,
-      softmaxstats,
-      o,
-      dropoutseed,
-      dropoutoffset,
-      handle);
+      dropout_probability,
+      is_causal,
+      return_softmaxstats,
+      true);
+  auto graph_ptr = getMHAGraphCache_().find(key);
+  std::shared_ptr<fe::graph::Graph> mha_graph;
+
+  if (graph_ptr) {
+    mha_graph = *graph_ptr;
+  } else {
+    mha_graph = build_graph_nestedtensor(
+        b,
+        h_q,
+        h_k,
+        h_v,
+        s_q,
+        s_kv,
+        d_qk,
+        d_v,
+        scaling_factor,
+        return_softmaxstats,
+        is_causal,
+        dropout_probability,
+        cum_seqlen_q,
+        cum_seqlen_kv,
+        q,
+        k,
+        v,
+        attn_bias,
+        softmaxstats,
+        o,
+        dropoutseed,
+        dropoutoffset,
+        handle);
+  }
   auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
   auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
   auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
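The context lines closing this hunk are the varlen bookkeeping that feeds the graph at run time rather than compile time: per-sequence lengths are recovered as the adjacent difference of the cumulative offsets (at::diff), and the ragged offsets scale those cumulative offsets by each token's element count. A toy illustration with invented lengths (not the PyTorch code); reading rag_q_off as per-sequence element offsets into the packed Q buffer is my interpretation:

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the bookkeeping above: seqlen_q is the adjacent difference of
// cum_seqlen_q (what at::diff computes), and the ragged offset for Q scales
// cum_seqlen_q by h_q * d_qk, the elements each token occupies in THD layout.
int main() {
  std::vector<int64_t> cum_seqlen_q = {0, 3, 8, 10};  // prefix sums of lengths
  const int64_t h_q = 16, d_qk = 64;
  for (size_t i = 1; i < cum_seqlen_q.size(); ++i)
    std::cout << "seqlen_q[" << i - 1 << "] = "
              << cum_seqlen_q[i] - cum_seqlen_q[i - 1] << '\n';  // 3, 5, 2
  for (auto c : cum_seqlen_q)
    std::cout << "rag_q_off = " << c * h_q * d_qk << '\n';  // 0, 3072, 8192, 10240
  return 0;
}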
@@ -1636,7 +1664,8 @@ void run_cudnn_SDP_bprop(
       attn_bias,
       dropout_probability,
       is_causal,
-      true);
+      true,
+      false);
   auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key);
   std::shared_ptr<fe::graph::Graph> mha_graph;
   if (graph_backward_ptr) {
@@ -1761,33 +1790,55 @@ void run_cudnn_SDP_bprop_nestedtensor(
 
   cudnnHandle_t handle = getCudnnHandle();
 
-  auto mha_graph = build_graph_backward_nestedtensor(
+  auto key = MHACacheKeyWrapper(
       b,
       h_q,
-      h_k,
-      h_v,
-      s_q,
-      s_kv,
+      s_q, // max-seqlen-q
+      s_kv, // max-seqlen-kv
       d_qk,
       d_v,
-      scaling_factor,
-      is_causal,
-      dropout_probability,
-      cum_seqlen_q,
-      cum_seqlen_kv,
       q,
       k,
       v,
       attn_bias,
-      o,
-      dO_,
-      softmaxstats,
-      dQ,
-      dK,
-      dV,
-      dropoutseed,
-      dropoutoffset,
-      handle);
+      dropout_probability,
+      is_causal,
+      true,
+      true);
+  auto graph_ptr = getMHAGraphCache_().find(key);
+  std::shared_ptr<fe::graph::Graph> mha_graph;
+
+  if (graph_ptr) {
+    mha_graph = *graph_ptr;
+  } else {
+    mha_graph = build_graph_backward_nestedtensor(
+        b,
+        h_q,
+        h_k,
+        h_v,
+        s_q,
+        s_kv,
+        d_qk,
+        d_v,
+        scaling_factor,
+        is_causal,
+        dropout_probability,
+        cum_seqlen_q,
+        cum_seqlen_kv,
+        q,
+        k,
+        v,
+        attn_bias,
+        o,
+        dO_,
+        softmaxstats,
+        dQ,
+        dK,
+        dV,
+        dropoutseed,
+        dropoutoffset,
+        handle);
+  }
 
   std::unordered_map<int64_t, void*> variant_pack = {
       // inputs

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
 
   const auto dprop = at::cuda::getCurrentDeviceProperties();
   // Check that the input is nested
-  if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
+  if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
     if (debug) {
       TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0.");
     }
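This one-character fix inverts the architecture guard: previously the nested-input check fired on exactly the supported GPUs (SM 9.0 / SM 10.0); with the added '!' it fires only on everything else. A tiny sketch of the corrected predicate (assuming SM 9.0 is Hopper and SM 10.0 is Blackwell; names here are illustrative, not the PyTorch code):

#include <cassert>

// Corrected guard: nested inputs are rejected exactly when the device is
// NOT SM 9.0 / SM 10.0. The pre-fix condition (without '!') rejected the
// supported architectures instead.
bool reject_nested(int sm_major) {
  return !(sm_major == 9 || sm_major == 10);
}

int main() {
  assert(!reject_nested(9));   // SM 9.0 (Hopper): nested tensors supported
  assert(!reject_nested(10));  // SM 10.0 (Blackwell): supported
  assert(reject_nested(8));    // e.g. SM 8.x: check fires, warns in debug mode
  return 0;
}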
