Skip to content

Commit 9195b21

Browse files
committed
[CI] Separate GPU and CPU tests with pytest markers
Add pytest.mark.gpu to tests that require CUDA, and update run_all.sh to filter tests based on whether running on GPU or CPU machines.

Changes:
- Register the 'gpu' marker in pytest.ini and conftest.py
- Add pytest.mark.gpu to ~30 tests that explicitly require CUDA
- Update run_all.sh to use GPU_MARKER_FILTER:
  - GPU jobs (CU_VERSION != cpu): run only pytest.mark.gpu tests
  - CPU jobs (CU_VERSION = cpu): run all tests except pytest.mark.gpu tests

This significantly reduces GPU machine usage by running only GPU-requiring tests on expensive GPU runners (~30 tests instead of ~2000+). Tests that can run on either device will run on CPU machines only. The optimization can be disabled by setting TORCHRL_GPU_FILTER=0.

ghstack-source-id: 9235913
Pull-Request: #3404
1 parent 5bfe43f commit 9195b21

File tree

13 files changed

+60
-4
lines changed

13 files changed

+60
-4
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,28 @@ fi
269269

270270
TORCHRL_TEST_SUITE="${TORCHRL_TEST_SUITE:-all}" # all|distributed|nondistributed
271271

272+
# GPU test filtering: Run GPU-only tests on GPU machines, CPU-only tests on CPU machines.
273+
# This avoids running ~2000+ tests on expensive GPU machines when only ~30 require GPU.
274+
# Tests are marked with @pytest.mark.gpu if they require CUDA.
275+
#
276+
# Set TORCHRL_GPU_FILTER=0 to disable this optimization and run all tests.
277+
#
278+
# We use an array to handle the marker expression properly (avoids quoting issues).
279+
GPU_MARKER_FILTER=()
280+
if [ "${TORCHRL_GPU_FILTER:-1}" = "1" ]; then
281+
if [ "${CU_VERSION:-}" == cpu ]; then
282+
# CPU job: run only tests that do NOT require GPU
283+
GPU_MARKER_FILTER=(-m 'not gpu')
284+
echo "GPU filtering enabled: Running CPU-only tests (excluding @pytest.mark.gpu)"
285+
else
286+
# GPU job: run only tests that require GPU
287+
GPU_MARKER_FILTER=(-m gpu)
288+
echo "GPU filtering enabled: Running GPU-only tests (@pytest.mark.gpu)"
289+
fi
290+
else
291+
echo "GPU filtering disabled: Running all tests"
292+
fi
293+
272294
export PYTORCH_TEST_WITH_SLOW='1'
273295
python -m torch.utils.collect_env
274296

@@ -287,6 +309,7 @@ run_distributed_tests() {
287309
return 1
288310
fi
289311
# Run both test_distributed.py and test_rb_distributed.py (both use torch.distributed)
312+
# Note: distributed tests always run on GPU, no need for GPU_MARKER_FILTER here
290313
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/test_rb_distributed.py \
291314
--instafail --durations 200 -vv --capture no \
292315
--timeout=120 --mp_fork_if_no_cuda
@@ -317,12 +340,12 @@ run_non_distributed_tests() {
317340
1)
318341
echo "Running shard 1: test_transforms.py only"
319342
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py \
320-
${common_args}
343+
"${GPU_MARKER_FILTER[@]}" ${common_args}
321344
;;
322345
2)
323346
echo "Running shard 2: test_envs.py and test_collectors.py"
324347
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_envs.py test/test_collectors.py \
325-
${common_args}
348+
"${GPU_MARKER_FILTER[@]}" ${common_args}
326349
;;
327350
3)
328351
echo "Running shard 3: All other tests"
@@ -332,13 +355,13 @@ run_non_distributed_tests() {
332355
--ignore test/test_envs.py \
333356
--ignore test/test_collectors.py \
334357
${xdist_args} \
335-
${common_args}
358+
"${GPU_MARKER_FILTER[@]}" ${common_args}
336359
;;
337360
all|"")
338361
echo "Running all tests (no sharding)"
339362
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
340363
${common_ignores} \
341-
${common_args}
364+
"${GPU_MARKER_FILTER[@]}" ${common_args}
342365
;;
343366
*)
344367
echo "Unknown TORCHRL_TEST_SHARD='${shard}'. Expected: all|1|2|3."

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ addopts =
66
--tb=native
77
markers =
88
unity_editor
9+
slow: mark test as slow to run
10+
gpu: mark test as requiring a GPU (CUDA device)
911
testpaths =
1012
test
1113
xfail_strict = True

test/compile/test_compile_collectors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_compiled_policy(self, collector_cls, compile_policy, device):
7777
collector.shutdown()
7878
del collector
7979

80+
@pytest.mark.gpu
8081
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
8182
@pytest.mark.parametrize(
8283
"collector_cls",

test/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def pytest_runtest_setup(item):
145145

146146
def pytest_configure(config):
147147
config.addinivalue_line("markers", "slow: mark test as slow to run")
148+
config.addinivalue_line(
149+
"markers", "gpu: mark test as requiring a GPU (CUDA device)"
150+
)
148151

149152

150153
def pytest_collection_modifyitems(config, items):

test/llm/test_llm_updaters.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_open_port():
7272
)
7373

7474

75+
@pytest.mark.gpu
7576
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
7677
@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
7778
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -415,6 +416,7 @@ def test_local_llm_specific_features(self, target_vllm_engine):
415416
"See LLM_TEST_ISSUES.md for details.",
416417
strict=False,
417418
)
419+
@pytest.mark.gpu
418420
@pytest.mark.skipif(not _has_ray, reason="missing ray dependencies")
419421
@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
420422
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
@@ -611,6 +613,7 @@ def test_weight_sync_vllm_collective_ray(self, request):
611613
ray.shutdown()
612614

613615

616+
@pytest.mark.gpu
614617
@pytest.mark.xfail(
615618
reason="AsyncVLLM tests fail due to Ray placement group timeout. "
616619
"See LLM_TEST_ISSUES.md for details.",

test/llm/test_vllm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def sampling_params():
3939
class TestAsyncVLLMIntegration:
4040
"""Integration tests for AsyncVLLM with real models."""
4141

42+
@pytest.mark.gpu
4243
@pytest.mark.skipif(not _has_vllm, reason="vllm not available")
4344
@pytest.mark.skipif(not _has_ray, reason="ray not available")
4445
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -111,6 +112,7 @@ def test_vllm_api_compatibility(self, sampling_params):
111112
finally:
112113
service.shutdown()
113114

115+
@pytest.mark.gpu
114116
@pytest.mark.skipif(not _has_vllm, reason="vllm not available")
115117
@pytest.mark.skipif(not _has_ray, reason="ray not available")
116118
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")

test/llm/test_wrapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,7 @@ def test_log_probs_consistency(
21042104
"See LLM_TEST_ISSUES.md for details.",
21052105
strict=False,
21062106
)
2107+
@pytest.mark.gpu
21072108
@pytest.mark.skipif(not _has_vllm, reason="vllm not available")
21082109
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
21092110
def test_sync_async_vllm_strict_equivalence(

test/test_collectors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2333,6 +2333,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
23332333
def _set_seed(self, seed: int | None) -> None:
23342334
...
23352335

2336+
@pytest.mark.gpu
23362337
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
23372338
@pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
23382339
@pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
@@ -2371,6 +2372,7 @@ def test_no_synchronize(self, env_device, storing_device, no_cuda_sync):
23712372
assert u == i, i
23722373
mock_synchronize.assert_not_called()
23732374

2375+
@pytest.mark.gpu
23742376
@pytest.mark.parametrize("device", ["cuda", "cpu"])
23752377
@pytest.mark.parametrize("storing_device", ["cuda", "cpu"])
23762378
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
@@ -3162,6 +3164,7 @@ def test_multi_collector_consistency(
31623164
assert_allclose_td(c2.unsqueeze(0), d2)
31633165

31643166

3167+
@pytest.mark.gpu
31653168
@pytest.mark.skipif(
31663169
not torch.cuda.is_available() and (not has_mps()),
31673170
reason="No casting if no cuda",
@@ -3363,6 +3366,7 @@ def test_param_sync_mixed_device(
33633366
col.shutdown()
33643367
del col
33653368

3369+
@pytest.mark.gpu
33663370
@pytest.mark.skipif(
33673371
not torch.cuda.is_available() or torch.cuda.device_count() < 3,
33683372
reason="requires at least 3 CUDA devices",

test/test_envs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ def test_auto_spec(self, env_type):
597597
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
598598
env.check_env_specs(tensordict=td.copy())
599599

600+
@pytest.mark.gpu
600601
@pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device found.")
601602
@pytest.mark.parametrize("break_when_any_done", [True, False])
602603
def test_auto_cast_to_device(self, break_when_any_done):
@@ -1526,6 +1527,7 @@ def test_parallel_env_with_policy(
15261527
# env_serial.close()
15271528
env0.close(raise_if_closed=False)
15281529

1530+
@pytest.mark.gpu
15291531
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
15301532
@pytest.mark.parametrize("heterogeneous", [False, True])
15311533
def test_transform_env_transform_no_device(
@@ -1638,6 +1640,7 @@ def test_parallel_env_custom_method(self, parallel, maybe_fork_ParallelEnv):
16381640
finally:
16391641
env.close(raise_if_closed=False)
16401642

1643+
@pytest.mark.gpu
16411644
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda to test on")
16421645
@pytest.mark.skipif(not _has_gym, reason="no gym")
16431646
@pytest.mark.parametrize("frame_skip", [4])
@@ -1742,6 +1745,7 @@ def test_parallel_env_cast(
17421745
env_serial.close(raise_if_closed=False)
17431746
env0.close(raise_if_closed=False)
17441747

1748+
@pytest.mark.gpu
17451749
@pytest.mark.skipif(not _has_gym, reason="no gym")
17461750
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device detected")
17471751
@pytest.mark.parametrize("frame_skip", [4])
@@ -2726,6 +2730,7 @@ def test_marl_group_type(group_type):
27262730
check_marl_grouping(group_type.get_group_map(agent_names), agent_names)
27272731

27282732

2733+
@pytest.mark.gpu
27292734
@pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device")
27302735
class TestConcurrentEnvs:
27312736
"""Concurrent parallel envs on multiple procs can interfere."""

test/test_libs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,6 +2157,7 @@ def test_set_seed_and_reset_works(self):
21572157

21582158
assert isinstance(td, TensorDict)
21592159

2160+
@pytest.mark.gpu
21602161
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
21612162
def test_dmcontrol_kwargs_preserved_with_seed(self):
21622163
"""Test that kwargs like camera_id are preserved when seed is provided.
@@ -2182,6 +2183,7 @@ def test_dmcontrol_kwargs_preserved_with_seed(self):
21822183
finally:
21832184
env.close()
21842185

2186+
@pytest.mark.gpu
21852187
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
21862188
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
21872189
@pytest.mark.parametrize("frame_skip", [1, 3])
@@ -2776,6 +2778,7 @@ def test_multithread_env_shutdown(self):
27762778
assert not env.is_closed
27772779
env.close()
27782780

2781+
@pytest.mark.gpu
27792782
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda to test on")
27802783
@pytest.mark.skipif(not _has_gym, reason="no gym")
27812784
@pytest.mark.parametrize("frame_skip", [4])
@@ -2816,6 +2819,7 @@ def test_multithreaded_env_cast(
28162819
assert td_device.device == torch.device(device), env_multithread
28172820
env_multithread.close()
28182821

2822+
@pytest.mark.gpu
28192823
@pytest.mark.skipif(not _has_gym, reason="no gym")
28202824
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device detected")
28212825
@pytest.mark.parametrize("frame_skip", [4])
@@ -3097,6 +3101,7 @@ def test_brax_automatic_cache_clearing_parameter(self, envname, device, freq):
30973101
out_td, next_td = env.step_and_maybe_reset(next_td)
30983102
assert env._step_count == i + 1
30993103

3104+
@pytest.mark.gpu
31003105
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
31013106
def test_brax_kwargs_preserved_with_seed(self, envname, device):
31023107
"""Test that kwargs like camera_id are preserved when seed is provided.

Commit comments (0)