Skip to content

Commit 3d3d05e

Browse files
committed
Have seed and offset in int64 for cudnn-frontend SDPA
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 280c57e commit 3d3d05e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

thunder/executors/cudnn_sdpa.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ def _make_cudnn_sdpa_forward_graph(
9797
Offset = None
9898
if dropout_p != 0.0:
9999
Seed = graph.tensor(
100-
name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
100+
name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64
101101
)
102102
Offset = graph.tensor(
103-
name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
103+
name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64
104104
)
105105
dropout_tuple = (dropout_p, Seed, Offset)
106106

@@ -450,10 +450,10 @@ def _make_cudnn_sdpa_backward_graph(
450450
Offset = None
451451
if dropout_p != 0.0:
452452
Seed = graph.tensor(
453-
name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
453+
name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64
454454
)
455455
Offset = graph.tensor(
456-
name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
456+
name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64
457457
)
458458
dropout_tuple = (dropout_p, Seed, Offset)
459459

0 commit comments

Comments
 (0)