diff --git a/thunder/executors/cudnn_sdpa.py b/thunder/executors/cudnn_sdpa.py index 00fb76424a..aef74b3225 100644 --- a/thunder/executors/cudnn_sdpa.py +++ b/thunder/executors/cudnn_sdpa.py @@ -97,10 +97,10 @@ def _make_cudnn_sdpa_forward_graph( Offset = None if dropout_p != 0.0: Seed = graph.tensor( - name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32 + name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64 ) Offset = graph.tensor( - name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32 + name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64 ) dropout_tuple = (dropout_p, Seed, Offset) @@ -450,10 +450,10 @@ def _make_cudnn_sdpa_backward_graph( Offset = None if dropout_p != 0.0: Seed = graph.tensor( - name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32 + name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64 ) Offset = graph.tensor( - name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32 + name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT64 ) dropout_tuple = (dropout_p, Seed, Offset)