@@ -3114,14 +3114,20 @@ def test_ctc_decode(self, dtype):
3114
3114
self .assertEqual (standardize_dtype (decoded .dtype ), "int32" )
3115
3115
self .assertEqual (standardize_dtype (scores .dtype ), expected_dtype )
3116
3116
3117
- @parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
3118
- def test_dot_product_attention (self , dtype ):
3117
+ @parameterized .named_parameters (
3118
+ named_product (
3119
+ dtypes = list (combinations (FLOAT_DTYPES , 2 ))
3120
+ + [(dtype , dtype ) for dtype in FLOAT_DTYPES ]
3121
+ )
3122
+ )
3123
+ def test_dot_product_attention (self , dtypes ):
3119
3124
# TODO: Get expected output from jax if `jax.nn.dot_product_attention`
3120
3125
# is available.
3121
- query = knp .ones ((2 , 3 , 3 , 8 ), dtype = dtype )
3122
- key = knp .ones ((2 , 3 , 3 , 8 ), dtype = dtype )
3123
- value = knp .ones ((2 , 3 , 3 , 8 ), dtype = dtype )
3124
- expected_dtype = dtype
3126
+ query_dtype , key_value_dtype = dtypes
3127
+ query = knp .ones ((2 , 3 , 3 , 8 ), dtype = query_dtype )
3128
+ key = knp .ones ((2 , 3 , 3 , 8 ), dtype = key_value_dtype )
3129
+ value = knp .ones ((2 , 3 , 3 , 8 ), dtype = key_value_dtype )
3130
+ expected_dtype = backend .result_type (* dtypes )
3125
3131
3126
3132
self .assertDType (
3127
3133
knn .dot_product_attention (query , key , value ), expected_dtype
0 commit comments