Skip to content

Commit dca0104

Browse files
ydshiehBernardZach
authored andcommitted
Simplify running tests in a subprocess (huggingface#34213)
* check * check * check * check * add docstring --------- Co-authored-by: ydshieh <[email protected]>
1 parent fd480e0 commit dca0104

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

src/transformers/testing_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,46 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
23662366
test_case.fail(f'{results["error"]}')
23672367

23682368

2369+
def run_test_using_subprocess(func):
2370+
"""
2371+
To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
2372+
issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`).
2373+
"""
2374+
import pytest
2375+
2376+
@functools.wraps(func)
2377+
def wrapper(*args, **kwargs):
2378+
if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
2379+
func(*args, **kwargs)
2380+
else:
2381+
test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
2382+
try:
2383+
import copy
2384+
2385+
env = copy.deepcopy(os.environ)
2386+
env["_INSIDE_SUB_PROCESS"] = "1"
2387+
2388+
# If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
2389+
if "pytestconfig" in kwargs:
2390+
command = list(kwargs["pytestconfig"].invocation_params.args)
2391+
for idx, x in enumerate(command):
2392+
if x in kwargs["pytestconfig"].args:
2393+
test = test.split("::")[1:]
2394+
command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
2395+
command = [f"{sys.executable}", "-m", "pytest"] + command
2396+
command = [x for x in command if x not in ["--no-summary"]]
2397+
# Otherwise, simply run the test with no option at all
2398+
else:
2399+
command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
2400+
2401+
subprocess.run(command, env=env, check=True, capture_output=True)
2402+
except subprocess.CalledProcessError as e:
2403+
exception_message = e.stdout.decode()
2404+
raise pytest.fail(exception_message, pytrace=False)
2405+
2406+
return wrapper
2407+
2408+
23692409
"""
23702410
The following contains utils to run the documentation tests without having to overwrite any files.
23712411

tests/models/imagegpt/test_modeling_imagegpt.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import unittest
1919

2020
from transformers import ImageGPTConfig
21-
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
21+
from transformers.testing_utils import require_torch, require_vision, run_test_using_subprocess, slow, torch_device
2222
from transformers.utils import cached_property, is_torch_available, is_vision_available
2323

2424
from ...generation.test_utils import GenerationTesterMixin
@@ -257,11 +257,9 @@ def _check_scores(self, batch_size, scores, length, config):
257257
self.assertEqual(len(scores), length)
258258
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
259259

260-
@unittest.skip(
261-
reason="After #33632, this test still passes, but many subsequential tests fail with `device-side assert triggered`"
262-
)
260+
@run_test_using_subprocess
263261
def test_beam_search_generate_dict_outputs_use_cache(self):
264-
pass
262+
super().test_beam_search_generate_dict_outputs_use_cache()
265263

266264
def setUp(self):
267265
self.model_tester = ImageGPTModelTester(self)

tests/models/video_llava/test_modeling_video_llava.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@
2828
is_torch_available,
2929
is_vision_available,
3030
)
31-
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
31+
from transformers.testing_utils import (
32+
require_bitsandbytes,
33+
require_torch,
34+
require_torch_gpu,
35+
run_test_using_subprocess,
36+
slow,
37+
torch_device,
38+
)
3239

3340
from ...generation.test_utils import GenerationTesterMixin
3441
from ...test_configuration_common import ConfigTester
@@ -248,9 +255,7 @@ def test_flash_attn_2_fp32_ln(self):
248255
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
249256
pass
250257

251-
@unittest.skip(
252-
reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`"
253-
)
258+
@run_test_using_subprocess
254259
def test_mixed_input(self):
255260
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
256261
for model_class in self.all_model_classes:

0 commit comments

Comments
 (0)