|
8 | 8 | import functools
|
9 | 9 | import gc
|
10 | 10 | import importlib.util
|
| 11 | +import os |
11 | 12 | import urllib.error
|
12 | 13 |
|
| 14 | + |
13 | 15 | _has_isaac = importlib.util.find_spec("isaacgym") is not None
|
14 | 16 |
|
15 | 17 | if _has_isaac:
|
|
19 | 21 | from torchrl.envs.libs.isaacgym import IsaacGymEnv
|
20 | 22 | import argparse
|
21 | 23 | import importlib
|
22 |
| -import os |
23 | 24 |
|
24 | 25 | import time
|
25 | 26 | import urllib
|
@@ -2414,6 +2415,28 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):
|
2414 | 2415 | @pytest.mark.parametrize("device", get_available_devices())
|
2415 | 2416 | @pytest.mark.parametrize("envname", ["fast"])
|
2416 | 2417 | class TestBrax:
|
@pytest.fixture(autouse=True)
def _setup_jax(self):
    """Prepare JAX before every test of this class.

    Two things happen, in order:

    1. Memory-related environment variables are given defaults. ``setdefault``
       is used on purpose so that a value already exported by the user or the
       CI job is never overridden.
    2. ``jax.devices()`` is called to force backend initialization. If that
       raises, JAX is pinned to the CPU platform so the tests can still run.

    Note that the environment variables and the JAX config are process-wide
    and are intentionally left in place after the test finishes.
    """
    import logging

    import jax

    # Defaults intended to keep JAX/XLA (and TensorFlow, if present) from
    # grabbing most of the GPU memory up front.
    env_defaults = (
        ("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
        ("XLA_PYTHON_CLIENT_ALLOCATOR", "platform"),
        ("TF_FORCE_GPU_ALLOW_GROWTH", "true"),
    )
    for name, value in env_defaults:
        os.environ.setdefault(name, value)

    try:
        # Forces backend discovery; this is where a broken GPU setup surfaces.
        jax.devices()
    except Exception as err:
        # Deliberately broad: this is a best-effort fallback and the exact
        # exception type depends on the backend. The failure is logged so a
        # silent switch to CPU does not hide a misconfigured GPU runner.
        logging.getLogger(__name__).warning(
            "JAX backend initialization failed (%s); falling back to CPU.", err
        )
        os.environ["JAX_PLATFORM_NAME"] = "cpu"
        jax.config.update("jax_platform_name", "cpu")
| 2439 | + |
2417 | 2440 | @pytest.mark.parametrize("requires_grad", [False, True])
|
2418 | 2441 | def test_brax_constructor(self, envname, requires_grad, device):
|
2419 | 2442 | env0 = BraxEnv(envname, requires_grad=requires_grad, device=device)
|
@@ -2545,6 +2568,75 @@ def make_brax():
|
2545 | 2568 | tensordict = env.rollout(3)
|
2546 | 2569 | assert tensordict.shape == torch.Size([n, *batch_size, 3])
|
2547 | 2570 |
|
def test_brax_memory_leak(self, envname, device):
    """Check that process memory stays bounded during a differentiable rollout.

    A batched ``BraxEnv`` with ``requires_grad=True`` is stepped ``num_steps``
    times with a linear policy, calling ``backward()`` periodically. The
    resident set size (RSS) of the current process is sampled before and
    after the loop and the growth must stay below ``max_increase_mb``.

    The test is skipped when ``psutil`` is not installed.
    """
    # psutil is an optional third-party package: skip rather than error out.
    psutil = pytest.importorskip("psutil")

    process = psutil.Process(os.getpid())

    def rss_mb():
        # Resident set size of this process, converted from bytes to MB.
        return process.memory_info().rss / 1024 / 1024

    num_steps = 200
    backward_every = 50  # run a backward pass once every this many steps
    max_increase_mb = 100  # tolerated RSS growth over the whole rollout

    env = BraxEnv(
        envname,
        batch_size=[10],
        requires_grad=True,
        device=device,
    )
    # Start from a clean slate so earlier allocations do not skew the baseline.
    env.clear_cache()
    gc.collect()
    env.set_seed(0)
    next_td = env.reset()

    first_obs_key = env.observation_keys[0]
    policy = TensorDictModule(
        torch.nn.Linear(
            env.observation_spec[first_obs_key].shape[-1],
            env.action_spec.shape[-1],
            device=device,
        ),
        in_keys=env.observation_keys[:1],
        out_keys=["action"],
    )

    initial_memory = rss_mb()
    for i in range(num_steps):
        policy(next_td)
        out_td, next_td = env.step_and_maybe_reset(next_td)
        if i % backward_every == 0:
            loss = out_td["next", "observation"].sum()
            loss.backward()
            # NOTE(review): indentation was lost in the captured source; the
            # detach is assumed to belong to the backward branch so the next
            # segment starts from a fresh graph - confirm against the repo.
            next_td = next_td.detach().clone()
    # No explicit gc.collect() here: the assertion message indicates the test
    # targets the env's automatic cache clearing, not manual collection.
    memory_increase = rss_mb() - initial_memory
    assert (
        memory_increase < max_increase_mb
    ), f"Memory leak with automatic clearing: {memory_increase:.2f} MB"
| 2610 | + |
def test_brax_cache_clearing(self, envname, device):
    """Smoke test: ``clear_cache`` can be called repeatedly without raising."""
    brax_env = BraxEnv(
        envname,
        batch_size=[1],
        requires_grad=True,
        device=device,
    )
    # One initial call plus five repeats, i.e. six back-to-back invocations.
    total_calls = 1 + 5
    for _ in range(total_calls):
        brax_env.clear_cache()
| 2616 | + |
@pytest.mark.parametrize("freq", [10, None, False])
def test_brax_automatic_cache_clearing_parameter(self, envname, device, freq):
    """Check how ``cache_clear_frequency`` is stored and that steps are counted.

    ``False`` must be kept as the ``False`` singleton, ``None`` must resolve
    to the default of 20, and any other value must be stored unchanged.
    Afterwards ten steps are taken and ``_step_count`` must track each one.
    """
    env = BraxEnv(
        envname,
        batch_size=[1],
        requires_grad=True,
        device=device,
        cache_clear_frequency=freq,
    )

    if freq is False:
        # Identity check on purpose: 0 == False, but 0 is not False.
        assert env._cache_clear_frequency is False
    else:
        expected_frequency = 20 if freq is None else freq  # 20: default value
        assert env._cache_clear_frequency == expected_frequency

    env.set_seed(0)
    next_td = env.reset()
    for steps_done in range(1, 11):
        next_td["action"] = env.action_spec.rand()
        _, next_td = env.step_and_maybe_reset(next_td)
        assert env._step_count == steps_done
| 2639 | + |
2548 | 2640 |
|
2549 | 2641 | @pytest.mark.skipif(not _has_vmas, reason="vmas not installed")
|
2550 | 2642 | class TestVmas:
|
|
0 commit comments