@@ -76,6 +76,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_sdpa_with_bool_attn_mask(self):
+        class ScaledDotProductAttention(torch.nn.Module):
+            def forward(self, query, key, value, attn_mask):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    query, key, value, attn_mask=attn_mask
+                )
+
+        model = ScaledDotProductAttention()
+        attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
+        attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
+        query = key = value = torch.randn(2, 4, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (query, key, value, attn_mask),
+            input_names=["query", "key", "value", "attn_mask"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
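For reference, below is a minimal standalone sketch of the round trip this test verifies, assuming onnxruntime is installed; the save path "sdpa_bool_mask.onnx" is illustrative only. In scaled_dot_product_attention, a boolean attn_mask marks positions that participate in attention with True and masked-out positions with False, which is why setting one query row entirely to False models a padding token.

import torch
import onnxruntime as ort  # assumed available

class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask
        )

model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8, dtype=torch.bool)  # True = attend
attn_mask[0, 0, 0, :] = False  # fully mask one query row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)

# Export with the dynamo-based exporter, as in the test above.
onnx_program = torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    opset_version=18,
    dynamo=True,
)
onnx_program.save("sdpa_bool_mask.onnx")  # illustrative path

# Run the exported model under ONNX Runtime and compare against eager PyTorch.
session = ort.InferenceSession("sdpa_bool_mask.onnx")
feeds = {
    "query": query.numpy(),
    "key": key.numpy(),
    "value": value.numpy(),
    "attn_mask": attn_mask.numpy(),
}
(actual,) = session.run(None, feeds)
expected = model(query, key, value, attn_mask)
# assert_close defaults to equal_nan=True, covering any NaN rows the
# fully masked query row may produce.
torch.testing.assert_close(expected, torch.from_numpy(actual))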