@@ -21,7 +21,7 @@ def test_merge_state(seq_len, num_heads, head_dim):
21
21
assert torch .allclose (v_merged , v_merged_std , atol = 1e-2 )
22
22
assert torch .allclose (s_merged , s_merged_std , atol = 1e-2 )
23
23
except GPUArchitectureError as e :
24
- pytest .skip (e . msg )
24
+ pytest .skip (str ( e ) )
25
25
26
26
27
27
@pytest .mark .parametrize ("seq_len" , [2048 ])
@@ -44,7 +44,7 @@ def test_merge_state_in_place(seq_len, num_heads, head_dim):
44
44
assert torch .allclose (s , s_std , atol = 1e-2 )
45
45
46
46
except GPUArchitectureError as e :
47
- pytest .skip (e . msg )
47
+ pytest .skip (str ( e ) )
48
48
49
49
50
50
@pytest .mark .parametrize ("seq_len" , [2048 ])
@@ -63,7 +63,7 @@ def test_merge_states(seq_len, num_states, num_heads, head_dim):
63
63
assert torch .allclose (v_merged , v_merged_std , atol = 1e-2 )
64
64
assert torch .allclose (s_merged , s_merged_std , atol = 1e-2 )
65
65
except GPUArchitectureError as e :
66
- pytest .skip (e . msg )
66
+ pytest .skip (str ( e ) )
67
67
68
68
69
69
@pytest .mark .parametrize ("seq_len" , [2048 ])
@@ -94,4 +94,4 @@ def test_variable_length_merge_states(seq_len, num_heads, head_dim):
94
94
assert torch .allclose (v_merged [i ], v_merged_std , atol = 1e-2 )
95
95
assert torch .allclose (s_merged [i ], s_merged_std , atol = 1e-2 )
96
96
except GPUArchitectureError as e :
97
- pytest .skip (e . msg )
97
+ pytest .skip (str ( e ) )
0 commit comments