@@ -97,10 +97,10 @@ def _make_cudnn_sdpa_forward_graph(
97
97
Offset = None
98
98
if dropout_p != 0.0 :
99
99
Seed = graph .tensor (
100
- name = "Seed" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT32
100
+ name = "Seed" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT64
101
101
)
102
102
Offset = graph .tensor (
103
- name = "Offset" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT32
103
+ name = "Offset" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT64
104
104
)
105
105
dropout_tuple = (dropout_p , Seed , Offset )
106
106
@@ -450,10 +450,10 @@ def _make_cudnn_sdpa_backward_graph(
450
450
Offset = None
451
451
if dropout_p != 0.0 :
452
452
Seed = graph .tensor (
453
- name = "Seed" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT32
453
+ name = "Seed" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT64
454
454
)
455
455
Offset = graph .tensor (
456
- name = "Offset" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT32
456
+ name = "Offset" , dim = scalar_dim_stride , stride = scalar_dim_stride , data_type = cudnn .data_type .INT64
457
457
)
458
458
dropout_tuple = (dropout_p , Seed , Offset )
459
459
0 commit comments