Skip to content

Commit 48c868d

Browse files
authored
Add unit tests (#7)
1 parent d1f0078 commit 48c868d

File tree

9 files changed

+174
-7
lines changed

9 files changed

+174
-7
lines changed

.github/workflows/pytest.yml

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
name: Test
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: ["*"]
8+
workflow_dispatch: # allows you to trigger manually
9+
10+
# When this workflow is queued, automatically cancel any previous running
11+
# or pending jobs from the same branch
12+
concurrency:
13+
group: pytest-${{ github.ref }}
14+
cancel-in-progress: true
15+
16+
defaults:
17+
run:
18+
shell: bash -l {0}
19+
20+
jobs:
21+
test:
22+
name: ${{ matrix.os }} Python ${{ matrix.python-version }} NumPy ${{ matrix.numpy-version }}
23+
runs-on: ${{ matrix.os }}
24+
strategy:
25+
fail-fast: false
26+
matrix:
27+
os: [ubuntu-latest]
28+
python-version: ["3.11", "3.14"]
29+
numpy-version: ["2.3.0", latest]
30+
include:
31+
# Test oldest supported Python and NumPy versions
32+
- os: ubuntu-latest
33+
python-version: "3.8"
34+
numpy-version: "1.18.0"
35+
# Test vs. NumPy nightly wheels
36+
- os: ubuntu-latest
37+
python-version: "3.14"
38+
numpy-version: "nightly"
39+
# Test issues re. preinstalled SSL certificates on different OSes
40+
- os: windows-latest
41+
python-version: "3.14"
42+
numpy-version: latest
43+
- os: macos-latest
44+
python-version: "3.14"
45+
numpy-version: latest
46+
47+
steps:
48+
- name: Checkout
49+
uses: actions/checkout@v6
50+
51+
- name: Set up Python ${{ matrix.python-version }}
52+
uses: actions/setup-python@v4
53+
with:
54+
python-version: ${{ matrix.python-version }}
55+
56+
- name: Install pinned NumPy
57+
if: matrix.numpy-version != 'latest' && matrix.numpy-version != 'nightly'
58+
run: python -m pip install numpy==${{ matrix.numpy-version }}
59+
60+
- name: Install nightly NumPy wheels
61+
if: matrix.numpy-version == 'nightly'
62+
run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ numpy
63+
64+
- name: Install package
65+
run: pip install .
66+
67+
- name: Smoke test
68+
run: python -c "import ml_datasets"
69+
70+
- name: Install test dependencies
71+
run: pip install pytest
72+
73+
- name: Run tests
74+
run: pytest

ml_datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
from .loaders.universal_dependencies import ud_ancora_pos_tags, ud_ewtb_pos_tags
99
from .loaders.dbpedia import dbpedia
1010
from .loaders.cmu import cmu
11+
from .loaders.cifar import cifar
12+
from .loaders.wikiner import wikiner

ml_datasets/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .cifar import cifar
12
from .cmu import cmu
23
from .dbpedia import dbpedia
34
from .imdb import imdb

ml_datasets/loaders/reuters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from ..util import get_file
55
from .._registry import register_loader
66

7-
8-
URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
7+
URL = "https://s3.amazonaws.com/text-datasets/reuters.pkl"
8+
WORD_INDEX_URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
99

1010

1111
@register_loader("reuters")
@@ -15,7 +15,7 @@ def reuters():
1515

1616

1717
def get_word_index(path="reuters_word_index.pkl"):
18-
path = get_file(path, origin=URL)
18+
path = get_file(path, origin=WORD_INDEX_URL)
1919
f = open(path, "rb")
2020
data = pickle.load(f, encoding="latin1")
2121
f.close()
@@ -60,7 +60,7 @@ def load_reuters(
6060
# https://raw.githubusercontent.com/fchollet/keras/master/keras/datasets/mnist.py
6161
# Copyright Francois Chollet, Google, others (2015)
6262
# Under MIT license
63-
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
63+
path = get_file(path, origin=URL)
6464
f = open(path, "rb")
6565
X, labels = pickle.load(f)
6666
f.close()

ml_datasets/test/__init__.py

Whitespace-only changes.

ml_datasets/test/test_datasets.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# TODO the tests below only verify that the various functions don't crash.
2+
# Expand them to test the actual output contents.
3+
4+
import platform
5+
6+
import pytest
7+
import numpy as np
8+
9+
import ml_datasets
10+
11+
NP_VERSION = tuple(int(x) for x in np.__version__.split(".")[:2])
12+
13+
# FIXME warning on NumPy 2.4 when downloading pre-computed pickles:
14+
# Python or NumPy boolean but got `align=0`.
15+
# Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)
16+
if NP_VERSION >= (2, 4):
17+
np_24_deprecation = pytest.mark.filterwarnings(
18+
"ignore::numpy.exceptions.VisibleDeprecationWarning",
19+
20+
)
21+
else:
22+
# Note: can't use `condition=NP_VERSION >= (2, 4)` on the decorator directly
23+
# as numpy.exceptions did not exist in old NumPy versions.
24+
np_24_deprecation = lambda x: x
25+
26+
27+
@np_24_deprecation
28+
def test_cifar():
29+
(X_train, y_train), (X_test, y_test) = ml_datasets.cifar()
30+
31+
32+
@pytest.mark.skip(reason="very slow download")
33+
def test_cmu():
34+
train, dev = ml_datasets.cmu()
35+
36+
37+
def test_dbpedia():
38+
train, dev = ml_datasets.dbpedia()
39+
40+
41+
def test_imdb():
42+
train, dev = ml_datasets.imdb()
43+
44+
45+
@np_24_deprecation
46+
def test_mnist():
47+
(X_train, y_train), (X_test, y_test) = ml_datasets.mnist()
48+
49+
50+
@pytest.mark.xfail(reason="403 Forbidden")
51+
def test_quora_questions():
52+
train, dev = ml_datasets.quora_questions()
53+
54+
55+
@np_24_deprecation
56+
def test_reuters():
57+
(X_train, y_train), (X_test, y_test) = ml_datasets.reuters()
58+
59+
60+
@pytest.mark.xfail(platform.system() == "Windows", reason="path issues")
61+
def test_snli():
62+
train, dev = ml_datasets.snli()
63+
64+
65+
@pytest.mark.xfail(reason="no default path")
66+
def test_stack_exchange():
67+
train, dev = ml_datasets.stack_exchange()
68+
69+
70+
def test_ud_ancora_pos_tags():
71+
(train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ancora_pos_tags()
72+
73+
74+
@pytest.mark.xfail(reason="str column where int expected")
75+
def test_ud_ewtb_pos_tags():
76+
(train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ewtb_pos_tags()
77+
78+
79+
@pytest.mark.xfail(reason="no default path")
80+
def test_wikiner():
81+
train, dev = ml_datasets.wikiner()

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[tool.pytest.ini_options]
2+
addopts = "--strict-markers --strict-config -v -r sxfE --color=yes --durations=10"
3+
xfail_strict = true
4+
filterwarnings = [
5+
"error",
6+
# FIXME spurious random download warnings; will cause trouble in downstream CI
7+
"ignore:Implicitly cleaning up <HTTPError 403:ResourceWarning",
8+
"ignore:unclosed <socket.socket:ResourceWarning",
9+
]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cloudpickle>=2.2
2-
numpy>=1.7.0
2+
numpy>=1.18
33
scipy>=1.7.0
44
tqdm>=4.10.0,<5.0.0
55
# Our libraries

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ long_description_content_type = text/markdown
1111
[options]
1212
zip_safe = true
1313
include_package_data = true
14-
python_requires = >=3.6
14+
python_requires = >=3.8
1515
install_requires =
1616
cloudpickle>=2.2
17-
numpy>=1.7.0
17+
numpy>=1.18
1818
tqdm>=4.10.0,<5.0.0
1919
# Our libraries
2020
srsly>=1.0.1,<4.0.0

0 commit comments

Comments
 (0)