|
1 | 1 | import unittest
|
2 | 2 | import time
|
3 | 3 | import numpy as np
|
| 4 | +import pytest |
| 5 | + |
| 6 | +from art import performance_monitor |
4 | 7 | from art.performance_monitor import ResourceMonitor, PerformanceTimer, HAS_TENSORFLOW, HAS_TORCH
|
5 | 8 |
|
6 | 9 |
|
@@ -42,12 +45,39 @@ def test_performance_timer(self):
|
42 | 45 | self.assertGreater(len(data["cpu_percent"]), 0)
|
43 | 46 |
|
44 | 47 |
|
@pytest.mark.parametrize(
    "has_nvml, gpu_count, expected_has_gpu",
    [
        # Without NVML the monitor must never report a GPU, whatever the count.
        (False, 0, False),
        (False, 1, False),
        (False, 2, False),

        # NVML present but zero devices counted: still no GPU.
        (True, 0, False),

        # NVML present and at least one device counted: GPU detected.
        (True, 1, True),
        (True, 2, True),
    ]
)
def test_gpu_detection(monkeypatch, has_nvml: bool, gpu_count: int, expected_has_gpu: bool):
    """
    Verify ResourceMonitor's GPU detection across the full truth table of
    NVML availability and reported device count.

    A GPU should be reported only when NVML is available AND at least one
    device was counted.
    """
    # Patch the module-level flags that ResourceMonitor reads at construction
    # time, then build a fresh monitor under those conditions.
    monkeypatch.setattr(performance_monitor, 'HAS_NVML', has_nvml)
    monkeypatch.setattr(performance_monitor, 'GPU_COUNT', gpu_count)

    monitor = ResourceMonitor()

    # The detected GPU status must match the expectation for this scenario.
    assert monitor.has_gpu == expected_has_gpu
| 79 | + |
45 | 80 | class TestGPUMonitoring(unittest.TestCase):
|
46 |
| - def test_gpu_detection(self): |
47 |
| - """Test that GPU detection works correctly.""" |
48 |
| - monitor = ResourceMonitor() |
49 |
| - # Check if has_gpu is correctly set based on available libraries |
50 |
| - self.assertEqual(monitor.has_gpu, (HAS_TENSORFLOW or HAS_TORCH)) |
51 | 81 |
|
52 | 82 | def test_gpu_data_collection(self):
|
53 | 83 | """Test GPU data is collected when available."""
|
|
0 commit comments