|
| 1 | +import os |
| 2 | +import unittest |
| 3 | + |
| 4 | +"""This is run inside images by libraries_test.py""" |
| 5 | + |
| 6 | +# Suppress noisy logs from libraries, especially during testing |
| 7 | +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" |
| 8 | +os.environ["KMP_WARNINGS"] = "0" |
| 9 | + |
| 10 | + |
| 11 | +# ruff: noqa: PLC0415 `import` should be at the top-level of a file |
| 12 | +class TestDataScienceLibs(unittest.TestCase): |
| 13 | + """A test suite to verify the basic functionality of key data science libraries.""" |
| 14 | + |
| 15 | + @classmethod |
| 16 | + def setUpClass(cls): |
| 17 | + """Set up data once for all tests in this class.""" |
| 18 | + print("--- 🧪 Verifying Data Science Environment ---") |
| 19 | + cls.image = os.environ["IMAGE"] |
| 20 | + print(f"Image: {cls.image}") |
| 21 | + |
| 22 | + def setUp(self): |
| 23 | + self.tear_downs = [] |
| 24 | + |
| 25 | + def tearDown(self): |
| 26 | + """Clean up resources after all tests in this class have run.""" |
| 27 | + for tear_down in self.tear_downs: |
| 28 | + tear_down() |
| 29 | + super().tearDown() |
| 30 | + |
| 31 | + def test_numpy(self): |
| 32 | + """Tests numpy array creation and basic operations.""" |
| 33 | + import numpy as np # pyright: ignore[reportMissingImports] |
| 34 | + |
| 35 | + arr = np.array([[1, 2], [3, 4]]) |
| 36 | + self.assertEqual(arr.shape, (2, 2), "Numpy array shape is incorrect.") |
| 37 | + self.assertEqual(np.sum(arr), 10, "Numpy sum calculation is incorrect.") |
| 38 | + print("✅ NumPy test passed.") |
| 39 | + |
| 40 | + def test_pandas(self): |
| 41 | + """Tests pandas DataFrame creation.""" |
| 42 | + import pandas as pd # pyright: ignore[reportMissingImports] |
| 43 | + |
| 44 | + df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]}) |
| 45 | + self.assertIsInstance(df, pd.DataFrame, "Object is not a Pandas DataFrame.") |
| 46 | + self.assertEqual(df.shape, (2, 2), "Pandas DataFrame shape is incorrect.") |
| 47 | + print("✅ Pandas test passed.") |
| 48 | + |
| 49 | + def test_sklearn(self): |
| 50 | + """Tests scikit-learn model fitting.""" |
| 51 | + from sklearn.cluster import KMeans # pyright: ignore[reportMissingImports] |
| 52 | + from sklearn.datasets import make_blobs # pyright: ignore[reportMissingImports] |
| 53 | + |
| 54 | + X, _y = make_blobs(n_samples=100, centers=3, random_state=42) |
| 55 | + |
| 56 | + model = KMeans(n_clusters=3, random_state=42, n_init="auto") |
| 57 | + model.fit(X) |
| 58 | + self.assertEqual(model.cluster_centers_.shape, (3, 2), "Cluster centers shape is incorrect.") |
| 59 | + self.assertIsNotNone(model.labels_, "Scikit-learn model failed to fit.") |
| 60 | + print("✅ Scikit-learn test passed.") |
| 61 | + |
| 62 | + def test_matplotlib(self): |
| 63 | + """Tests matplotlib plot creation and saving to a file.""" |
| 64 | + import matplotlib.pyplot as plt # pyright: ignore[reportMissingImports] |
| 65 | + from sklearn.datasets import make_blobs # pyright: ignore[reportMissingImports] |
| 66 | + |
| 67 | + X, y = make_blobs(n_samples=50, centers=3, n_features=2, random_state=42) |
| 68 | + plot_filename = "matplotlib_unittest.png" |
| 69 | + |
| 70 | + fig, ax = plt.subplots() |
| 71 | + ax.scatter(X[:, 0], X[:, 1], c=y) |
| 72 | + ax.set_title("Matplotlib Unittest") |
| 73 | + plt.savefig(plot_filename) |
| 74 | + self.tear_downs.append(lambda: os.remove(plot_filename)) |
| 75 | + plt.close(fig) # Close the figure to free up memory |
| 76 | + |
| 77 | + self.assertTrue(os.path.exists(plot_filename), "Matplotlib did not create the plot file.") |
| 78 | + print("✅ Matplotlib test passed.") |
| 79 | + |
| 80 | + def test_torch(self): |
| 81 | + """🧪 Tests basic PyTorch tensor operations.""" |
| 82 | + if "-pytorch-" not in self.image: |
| 83 | + self.skipTest("Not a Torch image") |
| 84 | + import torch # pyright: ignore[reportMissingImports] |
| 85 | + |
| 86 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 87 | + tensor = torch.rand(2, 3, device=device) |
| 88 | + self.assertEqual(tensor.shape, (2, 3), "PyTorch tensor shape is incorrect.") |
| 89 | + self.assertTrue(str(tensor.device).startswith(device), "Tensor was not created on the correct device.") |
| 90 | + print(f"✅ PyTorch test passed (using device: {device}).") |
| 91 | + |
| 92 | + def test_torchvision(self): |
| 93 | + """🧪 Tests torchvision model loading and inference.""" |
| 94 | + if "-pytorch-" not in self.image: |
| 95 | + self.skipTest("Not a Torch image") |
| 96 | + import torch # pyright: ignore[reportMissingImports] |
| 97 | + import torchvision # pyright: ignore[reportMissingImports] |
| 98 | + |
| 99 | + model = torchvision.models.resnet18(weights=None) # Use weights=None for faster testing |
| 100 | + model.eval() |
| 101 | + dummy_input = torch.randn(1, 3, 224, 224) |
| 102 | + with torch.no_grad(): |
| 103 | + output = model(dummy_input) |
| 104 | + self.assertEqual(output.shape, (1, 1000), "Torchvision model output shape is incorrect.") |
| 105 | + print("✅ Torchvision test passed.") |
| 106 | + |
| 107 | + def test_torchaudio(self): |
| 108 | + """🧪 Tests torchaudio waveform generation.""" |
| 109 | + if "-pytorch-" not in self.image: |
| 110 | + self.skipTest("Not a Torch image") |
| 111 | + import torchaudio # pyright: ignore[reportMissingImports] |
| 112 | + |
| 113 | + sample_rate = 16000 |
| 114 | + waveform = torchaudio.functional.generate_sine(440, sample_rate=sample_rate, duration=0.5) |
| 115 | + self.assertEqual(waveform.shape, (1, 8000), "Torchaudio waveform shape is incorrect.") |
| 116 | + print("✅ Torchaudio test passed.") |
| 117 | + |
| 118 | + |
| 119 | +if __name__ == "__main__": |
| 120 | + unittest.main(verbosity=2) |
0 commit comments