Skip to content

Commit 05b263b

Browse files
committed
Adds Pytest and removes testing scripts
1 parent 9e07e53 commit 05b263b

File tree

5 files changed

+40
-349
lines changed

5 files changed

+40
-349
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ dependencies = [
1212
"mlx>=0.15.2",
1313
"numpy>=2.0.0",
1414
"scikit-learn>=1.5.1",
15-
"skorch>=1.0.0",
1615
"tabulate>=0.9.0",
17-
"torch>=2.5.0",
1816
]
1917

2018
[project.optional-dependencies]

sklx/tests/test_classifier.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
from mlx import nn
3+
from mlx.nn import losses
4+
from sklearn.datasets import make_classification
5+
6+
from sklx.classifier import NeuralNetworkClassifier
7+
8+
9+
def test_neural_network_classifier():
10+
"""
11+
This is just a simple test to make sure the basic usage works.
12+
"""
13+
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
14+
X = X.astype(np.float32)
15+
y = y.astype(np.int64)
16+
17+
class MyModule(nn.Module):
18+
def __init__(self, num_units=10, nonlin=nn.ReLU()):
19+
super().__init__()
20+
self.layers = [
21+
nn.Linear(20, num_units),
22+
nonlin,
23+
nn.Dropout(0.5),
24+
nn.Linear(num_units, num_units),
25+
nn.Linear(num_units, 2),
26+
nn.LogSoftmax(),
27+
]
28+
29+
def __call__(self, X, **kwargs):
30+
for _, layer in enumerate(self.layers):
31+
X = layer(X)
32+
return X
33+
34+
net = NeuralNetworkClassifier(
35+
MyModule, max_epochs=10, lr=0.1, criterion=losses.nll_loss
36+
)
37+
38+
net.fit(X, y)
39+
net.predict_proba(X)

test_sklx.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

test_skorch.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

0 commit comments

Comments
 (0)