
Commit 5a5c892

layers only run for hip devices
1 parent: 701de1f

5 files changed (+14 lines, -7 lines)
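All five test files gain the same two-line change: an import of `is_hip_condition` from `sharktank.utils.testing`, and a `@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")` marker on each IREE-vs-eager test, so these layer tests only run when a HIP device is available. Two tests (`test_ffn_iree_vs_eager` and `test_linear_iree_vs_eager`) are also renamed with a `mock` infix to match the other mock-weight tests. The definition of `is_hip_condition` is not part of this diff; a minimal sketch of what such a module-level flag could look like follows (the environment-variable name is an assumption, not the actual sharktank implementation):

```python
# Hypothetical sketch of the flag in sharktank/utils/testing.py; the real
# definition is not shown in this commit. Assumes the target device is
# signaled through an environment variable -- the variable name is an
# assumption for illustration only.
import os

# A plain module-level boolean: pytest.mark.skipif evaluates its condition
# at collection time, when the test module is imported.
is_hip_condition = os.environ.get("IREE_DEVICE", "").startswith("hip")
```

Because `skipif` conditions are evaluated at collection time, the flag has to be computable at import; anything that probes hardware lazily would need a fixture or callable instead of a bare boolean.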

sharktank/tests/layers/ffn_with_iree_test.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -7,6 +7,7 @@
 import pytest
 from sharktank.utils._helpers import run_iree_vs_torch_fx
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
+from sharktank.utils.testing import is_hip_condition
 
 
 class FFN(torch.nn.Module):
@@ -24,8 +25,9 @@ def forward(self, x):
         return self.w_down(torch.nn.functional.gelu(self.w_up(x)))
 
 
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
-def test_ffn_iree_vs_eager(dtype, atol):
+def test_ffn_mock_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
     m = FFN(hidden=64, inter=128, dtype=dtype, activation="silu")
     x = torch.randn(2, 8, 64, dtype=dtype)
```

sharktank/tests/layers/linear_with_iree_test.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -8,6 +8,7 @@
 import pytest
 from sharktank.utils._helpers import run_iree_vs_torch_fx
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
+from sharktank.utils.testing import is_hip_condition
 
 
 class Linear(torch.nn.Module):
@@ -18,9 +19,9 @@ def __init__(self, in_f, out_f, bias=False, dtype=torch.float32):
     def forward(self, x):
         return self.lin(x)
 
-
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
-def test_linear_iree_vs_eager(dtype, atol):
+def test_linear_mock_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
     m = Linear(64, 64, bias=False, dtype=dtype)
     x = torch.randn(2, 8, 64, dtype=dtype)
```

sharktank/tests/layers/output_lm_test_with_iree.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -12,7 +12,7 @@
 from sharktank.types import Dataset, Theta
 from sharktank.layers.configs import LlamaModelConfig
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
-
+from sharktank.utils.testing import is_hip_condition
 
 class OutputLMHead(torch.nn.Module):
     """Standalone output_lm_head block extracted from PagedLlmModelV1"""
@@ -86,6 +86,7 @@ def create_output_lm_head_from_irpa(
 
 
 # Test cases
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_output_lm_head_iree_vs_eager(request, dtype, atol):
     """
@@ -116,7 +117,7 @@ def test_output_lm_head_iree_vs_eager(request, dtype, atol):
         parameters_path=irpa_path,
     )
 
-
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 def test_output_lm_head_mock():
     """
     Mock test with synthetic weights for OutputLMHead functionality.
```

sharktank/tests/layers/rms_norm_with_iree_test.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -8,7 +8,7 @@
 import pytest
 from sharktank.utils._helpers import run_iree_vs_torch_fx
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
-
+from sharktank.utils.testing import is_hip_condition
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden=64, eps=1e-5, dtype=torch.float32):
@@ -25,6 +25,7 @@ def forward(self, x):
         return y * self.weight  # broadcast over last dim
 
 
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.bfloat16, 1e-2)])
 def test_rms_norm_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
```
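A side note on the tolerances in this file: the parametrization pairs `torch.float32` with `atol=1e-4` but `torch.bfloat16` with a much looser `1e-2`. A plausible reason (an inference from the bf16 format, not stated in the commit) is that bfloat16 keeps only 8 mantissa bits, so values at this scale round-trip with roughly 1e-2 absolute error. A quick standalone check:

```python
# Standalone check of bfloat16 round-trip error on inputs shaped like the
# ones in these tests; illustrates why atol=1e-4 would be too tight for bf16.
import torch

torch.manual_seed(42)
x = torch.randn(2, 8, 64)
err = (x - x.to(torch.bfloat16).float()).abs().max().item()
print(f"max bf16 round-trip error: {err:.4f}")  # typically on the order of 1e-2
```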

sharktank/tests/layers/token_embedding_with_iree_test.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
 from sharktank.types.theta import Dataset
 from sharktank.utils._helpers import run_iree_vs_torch_fx, validate_and_get_irpa_path
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
+from sharktank.utils.testing import is_hip_condition
 
 
 class TokenEmbeddingSmall(torch.nn.Module):
@@ -22,7 +23,7 @@ def __init__(self, vocab_size=128, hidden=64, dtype=torch.float32):
     def forward(self, ids: torch.Tensor):
         return self.weight[ids]
 
-
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_token_embedding_iree_vs_eager(request, dtype, atol):
     torch.manual_seed(42)
@@ -43,6 +44,7 @@ def test_token_embedding_iree_vs_eager(request, dtype, atol):
     )
 
 
+@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
 @pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_token_embedding_mock_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
```
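Taken together, each guarded test stacks the HIP skip above `parametrize`, so on non-HIP machines every parametrized case is skipped individually with the same reason rather than failing at IREE compile time. A self-contained illustration of the resulting behavior (standalone; `is_hip_condition` is hard-coded to `False` here and is not the sharktank implementation):

```python
# Standalone illustration of the marker stacking used across this commit.
# Save as test_skip_pattern.py and run `pytest -rs test_skip_pattern.py`:
# both cases report SKIPPED (Test requires HIP device) on a non-HIP machine.
import pytest
import torch

is_hip_condition = False  # stand-in; the real flag lives in sharktank.utils.testing


@pytest.mark.skipif(not is_hip_condition, reason="Test requires HIP device")
@pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
def test_example_iree_vs_eager(dtype, atol):
    torch.manual_seed(42)
    x = torch.randn(2, 8, 64, dtype=dtype)
    # The real tests would compile with IREE and compare against eager here.
    assert torch.allclose(x, x, atol=atol)
```

With `pytest -rs`, the short test summary lists each skipped case and the shared reason, which is the observable effect of this commit on machines without a HIP device.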
