@@ -146,7 +146,7 @@ namespace native {
146146
147147namespace fe = cudnn_frontend;
148148
149- # define MAX_MHA_DIM 4
149+ constexpr uint8_t MAX_MHA_DIM = 4 ;
150150
151151// Whether we will use ragged offsets in the dense (non-nested) path
152152// to avoid recompilation
@@ -238,7 +238,8 @@ void setMHAParams(
238238 const std::optional<Tensor>& attn_bias,
239239 double dropout_probability,
240240 bool is_causal,
241- bool return_softmaxstats) {
241+ bool return_softmaxstats,
242+ bool is_nested) {
242243 memset (&params, 0 , sizeof (MHAParams));
243244 params.device_id = at::cuda::current_device ();
244245 params.dataType = fe::DataType_t::HALF;
@@ -255,23 +256,24 @@ void setMHAParams(
255256 params.is_causal = is_causal;
256257 params.return_softmaxstats = return_softmaxstats;
257258 params.has_attn_bias = attn_bias.has_value ();
259+ // Expect 4D dense tensor, 3D nested case (THD)
258260 TORCH_INTERNAL_ASSERT (
259- q.sizes ().size () == MAX_MHA_DIM,
261+ q.sizes ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
260262 " Q tensor has unexpected number of dims, please report a bug to PyTorch." );
261263 TORCH_INTERNAL_ASSERT (
262- q.strides ().size () == MAX_MHA_DIM,
264+ q.strides ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
263265 " Q tensor has unexpected number of dims, please report a bug to PyTorch." );
264266 TORCH_INTERNAL_ASSERT (
265- k.sizes ().size () == MAX_MHA_DIM,
267+ k.sizes ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
266268 " K tensor has unexpected number of dims, please report a bug to PyTorch." );
267269 TORCH_INTERNAL_ASSERT (
268- k.strides ().size () == MAX_MHA_DIM,
270+ k.strides ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
269271 " K tensor has unexpected number of dims, please report a bug to PyTorch." );
270272 TORCH_INTERNAL_ASSERT (
271- v.sizes ().size () == MAX_MHA_DIM,
273+ v.sizes ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
272274 " V tensor has unexpected number of dims, please report a bug to PyTorch." );
273275 TORCH_INTERNAL_ASSERT (
274- v.strides ().size () == MAX_MHA_DIM,
276+ v.strides ().size () == ( uint8_t )( MAX_MHA_DIM - ( uint8_t )is_nested) ,
275277 " V tensor has unexpected number of dims, please report a bug to PyTorch." );
276278 std::copy (q.sizes ().begin (), q.sizes ().end (), params.q_dim .begin ());
277279 std::copy (q.strides ().begin (), q.strides ().end (), params.q_stride .begin ());
@@ -320,7 +322,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
320322 const std::optional<Tensor>& attn_bias,
321323 double dropout_probability,
322324 bool is_causal,
323- bool return_softmaxstats) {
325+ bool return_softmaxstats,
326+ bool is_nested) {
324327 setMHAParams (
325328 this ->pod ,
326329 b,
@@ -335,7 +338,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
335338 attn_bias,
336339 dropout_probability,
337340 is_causal,
338- return_softmaxstats);
341+ return_softmaxstats,
342+ is_nested);
339343 }
340344};
341345
@@ -1386,7 +1390,8 @@ void run_cudnn_SDP_fprop(
13861390 attn_bias,
13871391 dropout_probability,
13881392 is_causal,
1389- return_softmaxstats);
1393+ return_softmaxstats,
1394+ false );
13901395 auto graph_ptr = getMHAGraphCache_ ().find (key);
13911396 std::shared_ptr<fe::graph::Graph> mha_graph;
13921397 if (graph_ptr) {
@@ -1484,30 +1489,53 @@ void run_cudnn_SDP_fprop_nestedtensor(
14841489 if (return_softmaxstats && !softmaxstats.defined ()) {
14851490 softmaxstats = at::empty ({q.size (0 ), h_q, 1 }, q.options ().dtype (kFloat ));
14861491 }
1487- auto mha_graph = build_graph_nestedtensor (
1492+
1493+ auto key = MHACacheKeyWrapper (
14881494 b,
14891495 h_q,
1490- h_k,
1491- h_v,
1492- s_q,
1493- s_kv,
1496+ s_q, // max-seqlen-q
1497+ s_kv, // max-seqlen-kv
14941498 d_qk,
14951499 d_v,
1496- scaling_factor,
1497- return_softmaxstats,
1498- is_causal,
1499- dropout_probability,
1500- cum_seqlen_q,
1501- cum_seqlen_kv,
15021500 q,
15031501 k,
15041502 v,
15051503 attn_bias,
1506- softmaxstats,
1507- o,
1508- dropoutseed,
1509- dropoutoffset,
1510- handle);
1504+ dropout_probability,
1505+ is_causal,
1506+ return_softmaxstats,
1507+ true );
1508+ auto graph_ptr = getMHAGraphCache_ ().find (key);
1509+ std::shared_ptr<fe::graph::Graph> mha_graph;
1510+
1511+ if (graph_ptr) {
1512+ mha_graph = *graph_ptr;
1513+ } else {
1514+ mha_graph = build_graph_nestedtensor (
1515+ b,
1516+ h_q,
1517+ h_k,
1518+ h_v,
1519+ s_q,
1520+ s_kv,
1521+ d_qk,
1522+ d_v,
1523+ scaling_factor,
1524+ return_softmaxstats,
1525+ is_causal,
1526+ dropout_probability,
1527+ cum_seqlen_q,
1528+ cum_seqlen_kv,
1529+ q,
1530+ k,
1531+ v,
1532+ attn_bias,
1533+ softmaxstats,
1534+ o,
1535+ dropoutseed,
1536+ dropoutoffset,
1537+ handle);
1538+ }
15111539 auto seqlen_q = at::diff (cum_seqlen_q, 1 , 0 );
15121540 auto seqlen_kv = at::diff (cum_seqlen_kv, 1 , 0 );
15131541 auto rag_q_off = cum_seqlen_q.mul (h_q * d_qk);
@@ -1636,7 +1664,8 @@ void run_cudnn_SDP_bprop(
16361664 attn_bias,
16371665 dropout_probability,
16381666 is_causal,
1639- true );
1667+ true ,
1668+ false );
16401669 auto graph_backward_ptr = getMHAGraphBackwardCache_ ().find (key);
16411670 std::shared_ptr<fe::graph::Graph> mha_graph;
16421671 if (graph_backward_ptr) {
@@ -1761,33 +1790,55 @@ void run_cudnn_SDP_bprop_nestedtensor(
17611790
17621791 cudnnHandle_t handle = getCudnnHandle ();
17631792
1764- auto mha_graph = build_graph_backward_nestedtensor (
1793+ auto key = MHACacheKeyWrapper (
17651794 b,
17661795 h_q,
1767- h_k,
1768- h_v,
1769- s_q,
1770- s_kv,
1796+ s_q, // max-seqlen-q
1797+ s_kv, // max-seqlen-kv
17711798 d_qk,
17721799 d_v,
1773- scaling_factor,
1774- is_causal,
1775- dropout_probability,
1776- cum_seqlen_q,
1777- cum_seqlen_kv,
17781800 q,
17791801 k,
17801802 v,
17811803 attn_bias,
1782- o,
1783- dO_,
1784- softmaxstats,
1785- dQ,
1786- dK,
1787- dV,
1788- dropoutseed,
1789- dropoutoffset,
1790- handle);
1804+ dropout_probability,
1805+ is_causal,
1806+ true ,
1807+ true );
1808+ auto graph_ptr = getMHAGraphBackwardCache_ ().find (key);
1809+ std::shared_ptr<fe::graph::Graph> mha_graph;
1810+
1811+ if (graph_ptr) {
1812+ mha_graph = *graph_ptr;
1813+ } else {
1814+ mha_graph = build_graph_backward_nestedtensor (
1815+ b,
1816+ h_q,
1817+ h_k,
1818+ h_v,
1819+ s_q,
1820+ s_kv,
1821+ d_qk,
1822+ d_v,
1823+ scaling_factor,
1824+ is_causal,
1825+ dropout_probability,
1826+ cum_seqlen_q,
1827+ cum_seqlen_kv,
1828+ q,
1829+ k,
1830+ v,
1831+ attn_bias,
1832+ o,
1833+ dO_,
1834+ softmaxstats,
1835+ dQ,
1836+ dK,
1837+ dV,
1838+ dropoutseed,
1839+ dropoutoffset,
1840+ handle);
1841+ }
17911842
17921843 std::unordered_map<int64_t , void *> variant_pack = {
17931844 // inputs
0 commit comments