Skip to content

Commit 9bc0c11

Browse files
authored
RHOAIENG-26513: tests(containers): add unittests for verifying data science libraries in workbench images (#1421)
Introduces a new test suite (`libraries_testunits.py`) to validate key data science libraries (NumPy, Pandas, scikit-learn, Matplotlib, PyTorch, TorchVision, and TorchAudio) within workbench container images. Adds integration of these tests in `libraries_test.py`. Extends `docker_utils.container_cp` to accept `PathLike` types.
1 parent ba014b4 commit 9bc0c11

File tree

3 files changed

+171
-1
lines changed

3 files changed

+171
-1
lines changed

tests/containers/docker_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import tarfile
99
import time
10+
from os import PathLike
1011
from typing import TYPE_CHECKING
1112

1213
import podman
@@ -45,7 +46,9 @@ def wait_for_exit(self) -> int:
4546
return container.attrs["State"]["ExitCode"]
4647

4748

48-
def container_cp(container: Container, src: str, dst: str, user: int | None = None, group: int | None = None) -> None:
49+
def container_cp(
50+
container: Container, src: str | PathLike, dst: str, user: int | None = None, group: int | None = None
51+
) -> None:
4952
"""
5053
Copies a directory into a container
5154
From https://stackoverflow.com/questions/46390309/how-to-copy-a-file-from-host-to-container-using-docker-py-docker-sdk
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import pathlib
4+
from typing import TYPE_CHECKING
5+
6+
from tests.containers import docker_utils
7+
from tests.containers.workbenches.workbench_image_test import WorkbenchContainer, grab_and_check_logs
8+
9+
if TYPE_CHECKING:
10+
import pytest_subtests
11+
12+
from tests.containers.conftest import Image
13+
14+
15+
class TestWorkbenchImage:
    """Smoke tests for the data science libraries shipped in a workbench image.

    A workbench image is an image running a web IDE that listens on port 8888.
    The test starts such a container, copies ``libraries_testunits.py`` into it
    and runs that file with the ``python3`` found inside the image.
    """

    def test_image_entrypoint_starts(
        self, subtests: pytest_subtests.SubTests, jupyterlab_datascience_image: Image
    ) -> None:
        """Start the workbench and run the in-image library unit tests.

        The image name label is handed to the in-container script through the
        ``IMAGE`` environment variable. The test fails when the script exits
        with a non-zero code; the script output is used as the failure message.
        """
        container = WorkbenchContainer(image=jupyterlab_datascience_image.name, user=1000, group_add=[0])
        try:
            try:
                container.start()
                # check explicitly that we can connect to the ide running in the workbench
                with subtests.test("Attempting to connect to the workbench..."):
                    container._connect()

                # the unit tests live next to this file and are copied into the running container
                unittests = pathlib.Path(__file__).parent / "libraries_testunits.py"
                docker_utils.container_cp(container.get_wrapped_container(), unittests, "/opt/app-root/src/")
                ecode, stdout = container.exec(
                    [
                        "env",
                        f"IMAGE={jupyterlab_datascience_image.labels['name']}",
                        "bash",
                        "-c",
                        "python3 /opt/app-root/src/libraries_testunits.py",
                    ]
                )
                # errors="replace" keeps a stray non-UTF-8 byte in the output from raising
                # UnicodeDecodeError here and hiding the actual exit code and test output
                stdout_decoded = stdout.decode(errors="replace")
                print(stdout_decoded)
                assert ecode == 0, stdout_decoded
            finally:
                # try to grab logs regardless of whether container started or not
                grab_and_check_logs(subtests, container)
        finally:
            # always stop the container, even when startup or the checks above failed
            docker_utils.NotebookContainer(container).stop(timeout=0)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""This is run inside images by libraries_test.py"""

import os
import unittest

# Suppress noisy logs from libraries, especially during testing.
# These are set at import time, before any test method imports the libraries.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["KMP_WARNINGS"] = "0"
9+
10+
11+
# ruff: noqa: PLC0415 `import` should be at the top-level of a file
12+
class TestDataScienceLibs(unittest.TestCase):
13+
"""A test suite to verify the basic functionality of key data science libraries."""
14+
15+
@classmethod
16+
def setUpClass(cls):
17+
"""Set up data once for all tests in this class."""
18+
print("--- 🧪 Verifying Data Science Environment ---")
19+
cls.image = os.environ["IMAGE"]
20+
print(f"Image: {cls.image}")
21+
22+
def setUp(self):
23+
self.tear_downs = []
24+
25+
def tearDown(self):
26+
"""Clean up resources after all tests in this class have run."""
27+
for tear_down in self.tear_downs:
28+
tear_down()
29+
super().tearDown()
30+
31+
def test_numpy(self):
32+
"""Tests numpy array creation and basic operations."""
33+
import numpy as np # pyright: ignore[reportMissingImports]
34+
35+
arr = np.array([[1, 2], [3, 4]])
36+
self.assertEqual(arr.shape, (2, 2), "Numpy array shape is incorrect.")
37+
self.assertEqual(np.sum(arr), 10, "Numpy sum calculation is incorrect.")
38+
print("✅ NumPy test passed.")
39+
40+
def test_pandas(self):
41+
"""Tests pandas DataFrame creation."""
42+
import pandas as pd # pyright: ignore[reportMissingImports]
43+
44+
df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})
45+
self.assertIsInstance(df, pd.DataFrame, "Object is not a Pandas DataFrame.")
46+
self.assertEqual(df.shape, (2, 2), "Pandas DataFrame shape is incorrect.")
47+
print("✅ Pandas test passed.")
48+
49+
def test_sklearn(self):
50+
"""Tests scikit-learn model fitting."""
51+
from sklearn.cluster import KMeans # pyright: ignore[reportMissingImports]
52+
from sklearn.datasets import make_blobs # pyright: ignore[reportMissingImports]
53+
54+
X, _y = make_blobs(n_samples=100, centers=3, random_state=42)
55+
56+
model = KMeans(n_clusters=3, random_state=42, n_init="auto")
57+
model.fit(X)
58+
self.assertEqual(model.cluster_centers_.shape, (3, 2), "Cluster centers shape is incorrect.")
59+
self.assertIsNotNone(model.labels_, "Scikit-learn model failed to fit.")
60+
print("✅ Scikit-learn test passed.")
61+
62+
def test_matplotlib(self):
63+
"""Tests matplotlib plot creation and saving to a file."""
64+
import matplotlib.pyplot as plt # pyright: ignore[reportMissingImports]
65+
from sklearn.datasets import make_blobs # pyright: ignore[reportMissingImports]
66+
67+
X, y = make_blobs(n_samples=50, centers=3, n_features=2, random_state=42)
68+
plot_filename = "matplotlib_unittest.png"
69+
70+
fig, ax = plt.subplots()
71+
ax.scatter(X[:, 0], X[:, 1], c=y)
72+
ax.set_title("Matplotlib Unittest")
73+
plt.savefig(plot_filename)
74+
self.tear_downs.append(lambda: os.remove(plot_filename))
75+
plt.close(fig) # Close the figure to free up memory
76+
77+
self.assertTrue(os.path.exists(plot_filename), "Matplotlib did not create the plot file.")
78+
print("✅ Matplotlib test passed.")
79+
80+
def test_torch(self):
81+
"""🧪 Tests basic PyTorch tensor operations."""
82+
if "-pytorch-" not in self.image:
83+
self.skipTest("Not a Torch image")
84+
import torch # pyright: ignore[reportMissingImports]
85+
86+
device = "cuda" if torch.cuda.is_available() else "cpu"
87+
tensor = torch.rand(2, 3, device=device)
88+
self.assertEqual(tensor.shape, (2, 3), "PyTorch tensor shape is incorrect.")
89+
self.assertTrue(str(tensor.device).startswith(device), "Tensor was not created on the correct device.")
90+
print(f"✅ PyTorch test passed (using device: {device}).")
91+
92+
def test_torchvision(self):
93+
"""🧪 Tests torchvision model loading and inference."""
94+
if "-pytorch-" not in self.image:
95+
self.skipTest("Not a Torch image")
96+
import torch # pyright: ignore[reportMissingImports]
97+
import torchvision # pyright: ignore[reportMissingImports]
98+
99+
model = torchvision.models.resnet18(weights=None) # Use weights=None for faster testing
100+
model.eval()
101+
dummy_input = torch.randn(1, 3, 224, 224)
102+
with torch.no_grad():
103+
output = model(dummy_input)
104+
self.assertEqual(output.shape, (1, 1000), "Torchvision model output shape is incorrect.")
105+
print("✅ Torchvision test passed.")
106+
107+
def test_torchaudio(self):
108+
"""🧪 Tests torchaudio waveform generation."""
109+
if "-pytorch-" not in self.image:
110+
self.skipTest("Not a Torch image")
111+
import torchaudio # pyright: ignore[reportMissingImports]
112+
113+
sample_rate = 16000
114+
waveform = torchaudio.functional.generate_sine(440, sample_rate=sample_rate, duration=0.5)
115+
self.assertEqual(waveform.shape, (1, 8000), "Torchaudio waveform shape is incorrect.")
116+
print("✅ Torchaudio test passed.")
117+
118+
119+
# Script entry point: run every test in this file when it is executed directly.
# verbosity=2 makes unittest report each test by name together with its outcome.
if __name__ == "__main__":
    unittest.main(verbosity=2)

0 commit comments

Comments
 (0)