@@ -11,17 +11,16 @@
 # q/k/v typically have shape (..., seq_len, head_dim) or (batch, num_heads, seq_len, head_dim)
 
 _TEST_CASES_DATA = [
-    ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False),
-    ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False),
-    ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True),
-    ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False),
-    ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True),
-    ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False),
+    ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False),
+    ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False),
+    ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False),
+    ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True),
+    ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False),
 ]
 
 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
-    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},
+    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
 }
 _TENSOR_DTYPES = [infinicore.float16, infinicore.float32]
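Each tuple appears to read as (q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal); that field order is inferred from torch's SDPA signature, not confirmed by this diff. A minimal sketch of how one of the new 4-D cases would map onto the PyTorch reference call:

```python
import torch

# One tuple from _TEST_CASES_DATA; the field order
# (q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal)
# is an assumption inferred from torch's SDPA signature.
q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal = (
    (1, 2, 8, 16),
    (1, 2, 8, 16),
    (1, 2, 8, 16),
    None,
    0.0,
    True,
)

# torch.nn.functional.scaled_dot_product_attention expects
# (batch, num_heads, seq_len, head_dim), so (1, 2, 8, 16) reads as
# batch=1, num_heads=2, seq_len=8, head_dim=16.
q = torch.randn(q_shape)
k = torch.randn(k_shape)
v = torch.randn(v_shape)

out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
print(out.shape)  # torch.Size([1, 2, 8, 16]) -- same shape as q
```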
@@ -68,9 +67,8 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
 
 def main():
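The loosened float32 tolerance (1e-4 to 1e-3) suggests the InfiniCore backend is not expected to match PyTorch's SDPA bit-for-bit. A hypothetical sketch of the comparison the harness presumably performs with the per-dtype tolerances from _TOLERANCE_MAP; `check_close` is an illustrative name, not an actual helper from the framework:

```python
import torch

# Hypothetical comparison step, assuming the harness runs both
# operators on identical inputs and checks the outputs against
# the per-dtype tolerances in _TOLERANCE_MAP.
def check_close(actual, expected, tol):
    torch.testing.assert_close(
        actual, expected, atol=tol["atol"], rtol=tol["rtol"]
    )

# With the loosened float32 tolerances from the diff above:
tol = {"atol": 1e-3, "rtol": 1e-3}
expected = torch.randn(1, 2, 8, 16)
actual = expected + 5e-4  # well within atol=1e-3
check_close(actual, expected, tol)
```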