
Commit a699c17

voltjiawooway777 authored and committed

Enable scaled_dot_product_attention test cases

1 parent 873daff · commit a699c17

File tree: 1 file changed (+8, -10 lines)

test/infinicore/ops/scaled_dot_product_attention.py

Lines changed: 8 additions & 10 deletions
@@ -11,17 +11,16 @@
 # q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)

 _TEST_CASES_DATA = [
-    ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False),
-    ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False),
-    ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True),
-    ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False),
-    ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True),
-    ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False),
+    ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False),
+    ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False),
+    ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False),
+    ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True),
+    ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False),
 ]

 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
-    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},
+    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
 }
 _TENSOR_DTYPES = [infinicore.float16, infinicore.float32]
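The replacement cases switch to four-dimensional shapes, i.e. the (batch, num_heads, seq_len, head_dim) layout mentioned in the comment above; the remaining tuple fields are presumably attn_mask, dropout_p, and is_causal. A minimal standalone sketch of how one such tuple could drive the PyTorch reference operator (tensor creation via torch.randn is illustrative and not taken from the test harness):

    # Standalone illustration, not part of the test file: exercise one 4-D case
    # against the PyTorch reference. Shape is the second tuple in _TEST_CASES_DATA.
    import torch

    q_shape = k_shape = v_shape = (1, 2, 8, 16)  # (batch, num_heads, seq_len, head_dim)
    q = torch.randn(q_shape, dtype=torch.float32)
    k = torch.randn(k_shape, dtype=torch.float32)
    v = torch.randn(v_shape, dtype=torch.float32)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    print(out.shape)  # torch.Size([1, 2, 8, 16]); the output keeps the query's shape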

@@ -68,9 +67,8 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)

-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)


 def main():
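With infinicore_operator no longer commented out, the harness can compare the InfiniCore result against the PyTorch reference under the per-dtype tolerances in _TOLERANCE_MAP. A rough sketch of that kind of check, assuming the InfiniCore output has already been converted back to a torch tensor (the conversion and the harness plumbing are omitted here):

    # Hedged sketch only: "actual" stands in for the InfiniCore result converted
    # to a torch tensor; "expected" is the PyTorch reference. The atol/rtol values
    # are the float32 entry from _TOLERANCE_MAP.
    import torch

    def assert_matches(actual: torch.Tensor, expected: torch.Tensor) -> None:
        torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)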
