Skip to content

Commit 5d1bdd4

Browse files
authored
Skip tests when there are import issues on CI worker (meta-pytorch#132)
* move the test that doesn't run on CI to a new folder * Revert "move the test that doesn't run on CI to a new folder" This reverts commit b2fd5c3. * move the right file... * use pytest skipif * catch all exceptions * change function name * fix * typo
1 parent 33abb3e commit 5d1bdd4

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

tests/unit_tests/actors/test_reference_actor.py renamed to tests/unit_tests/test_reference_actor.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,32 @@
88
Tests for reference_actor.py - compute_logprobs function
99
"""
1010

11+
import unittest
12+
1113
import pytest
1214
import torch
1315

14-
from forge.actors.reference_actor import compute_logprobs
1516

17+
def _import_error():
18+
try:
19+
import forge.actors.reference_actor # noqa: F401
20+
21+
return False
22+
except Exception:
23+
return True
1624

17-
class TestComputeLogprobs:
25+
26+
class TestComputeLogprobs(unittest.TestCase):
1827
"""Test the compute_logprobs utility function."""
1928

29+
@pytest.mark.skipif(
30+
_import_error(),
31+
reason="Import error, likely due to missing dependencies on CI.",
32+
)
2033
def test_compute_logprobs_basic(self):
2134
"""Test basic logprobs computation."""
35+
from forge.actors.reference_actor import compute_logprobs
36+
2237
batch_size = 1
2338
seq_len = 5
2439
vocab_size = 1000
@@ -36,8 +51,14 @@ def test_compute_logprobs_basic(self):
3651
assert result.shape == (batch_size, response_len)
3752
assert torch.all(result <= 0) # Log probabilities should be <= 0
3853

54+
@pytest.mark.skipif(
55+
_import_error(),
56+
reason="Import error, likely due to missing dependencies on CI.",
57+
)
3958
def test_compute_logprobs_with_temperature(self):
4059
"""Test logprobs computation with temperature scaling."""
60+
from forge.actors.reference_actor import compute_logprobs
61+
4162
batch_size = 1
4263
seq_len = 5
4364
vocab_size = 1000
@@ -55,8 +76,14 @@ def test_compute_logprobs_with_temperature(self):
5576
default_result = compute_logprobs(logits, input_ids)
5677
assert not torch.allclose(result, default_result)
5778

79+
@pytest.mark.skipif(
80+
_import_error(),
81+
reason="Import error, likely due to missing dependencies on CI.",
82+
)
5883
def test_compute_logprobs_single_token(self):
5984
"""Test logprobs computation with single token response."""
85+
from forge.actors.reference_actor import compute_logprobs
86+
6087
batch_size = 1
6188
seq_len = 5
6289
vocab_size = 1000
@@ -70,8 +97,14 @@ def test_compute_logprobs_single_token(self):
7097
assert result.shape == (batch_size, response_len)
7198
assert result.numel() == 1 # Single element
7299

100+
@pytest.mark.skipif(
101+
_import_error(),
102+
reason="Import error, likely due to missing dependencies on CI.",
103+
)
73104
def test_compute_logprobs_empty_response(self):
74105
"""Test logprobs computation with empty response."""
106+
from forge.actors.reference_actor import compute_logprobs
107+
75108
batch_size = 1
76109
seq_len = 5
77110
vocab_size = 1000
@@ -84,8 +117,14 @@ def test_compute_logprobs_empty_response(self):
84117

85118
assert result.shape == (batch_size, response_len)
86119

120+
@pytest.mark.skipif(
121+
_import_error(),
122+
reason="Import error, likely due to missing dependencies on CI.",
123+
)
87124
def test_compute_logprobs_empty_prompt(self):
88125
"""Test logprobs computation with empty prompt."""
126+
from forge.actors.reference_actor import compute_logprobs
127+
89128
batch_size = 1
90129
vocab_size = 1000
91130
prompt_len = 0

0 commit comments

Comments
 (0)