Commit 68f6a24 (parent e379733)

test for chebai cli, trains mlp

File tree: 4 files changed (+84, -0 lines)
chebai/preprocessing/datasets/mock_dm.py

Lines changed: 46 additions & 0 deletions
import torch
from lightning.pytorch.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader

from chebai.preprocessing.collate import RaggedCollator


class MyLightningDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._num_of_labels = None
        self._feature_vector_size = None
        self.collator = RaggedCollator()

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        self._num_of_labels = 10
        self._feature_vector_size = 20
        print(f"Number of labels: {self._num_of_labels}")
        print(f"Number of features: {self._feature_vector_size}")

    @property
    def num_of_labels(self):
        return self._num_of_labels

    @property
    def feature_vector_size(self):
        return self._feature_vector_size

    def train_dataloader(self):
        assert self.feature_vector_size is not None, "feature_vector_size must be set"
        # Dummy dataset for example purposes
        datalist = [
            {
                "features": torch.randn(self._feature_vector_size),
                "labels": torch.randint(0, 2, (self._num_of_labels,), dtype=torch.bool),
                "ident": i,
                "group": None,
            }
            for i in range(100)
        ]
        return DataLoader(datalist, batch_size=32, collate_fn=self.collator)
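
For orientation, here is a minimal smoke check (not part of the commit) that exercises the mock datamodule directly. It assumes only the interface defined above and the module path given in mock_dm_config.yml below.

from chebai.preprocessing.datasets.mock_dm import MyLightningDataModule

dm = MyLightningDataModule()
dm.setup()                      # sets num_of_labels=10, feature_vector_size=20
loader = dm.train_dataloader()  # 100 dummy samples, batch size 32
batch = next(iter(loader))      # collated by RaggedCollator
print(type(batch))              # collator-specific batch object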

tests/unit/cli/__init__.py

Whitespace-only changes.

tests/unit/cli/mock_dm_config.yml

Lines changed: 1 addition & 0 deletions
class_path: chebai.preprocessing.datasets.mock_dm.MyLightningDataModule
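
This one-line config is what the test below passes via --data=tests/unit/cli/mock_dm_config.yml. LightningCLI resolves a class_path entry by importing the module and instantiating the named class; roughly (an illustration of the convention, not chebai code):

import importlib

class_path = "chebai.preprocessing.datasets.mock_dm.MyLightningDataModule"
module_name, class_name = class_path.rsplit(".", 1)
cls = getattr(importlib.import_module(module_name), class_name)
datamodule = cls()  # any init_args from the YAML would be passed here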

tests/unit/cli/testCLI.py

Lines changed: 37 additions & 0 deletions
import unittest

from chebai.cli import ChebaiCLI


class TestChebaiCLI(unittest.TestCase):
    def setUp(self):
        self.cli_args = [
            "fit",
            "--trainer=configs/training/default_trainer.yml",
            "--model=configs/model/ffn.yml",
            "--model.init_args.hidden_layers=[10]",
            "--model.train_metrics=configs/metrics/micro-macro-f1.yml",
            "--model.test_metrics=configs/metrics/micro-macro-f1.yml",
            "--model.val_metrics=configs/metrics/micro-macro-f1.yml",
            "--data=tests/unit/cli/mock_dm_config.yml",
            "--model.pass_loss_kwargs=false",
            "--trainer.min_epochs=1",
            "--trainer.max_epochs=1",
            "--model.criterion=configs/loss/bce.yml",
            "--model.criterion.init_args.beta=0.99",
        ]

    def test_mlp_on_chebai_cli(self):
        # Instantiate ChebaiCLI and ensure no exceptions are raised
        try:
            ChebaiCLI(
                args=self.cli_args,
                save_config_kwargs={"config_filename": "lightning_config.yaml"},
                parser_kwargs={"parser_mode": "omegaconf"},
            )
        except Exception as e:
            self.fail(f"ChebaiCLI raised an unexpected exception: {e}")


if __name__ == "__main__":
    unittest.main()
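
The __main__ guard lets the file run standalone (python tests/unit/cli/testCLI.py). An equivalent discovery-based run is sketched below, assuming the repository root is the working directory, which the relative configs/ paths in cli_args require:

import unittest

suite = unittest.defaultTestLoader.discover("tests/unit/cli", pattern="testCLI.py")
unittest.TextTestRunner(verbosity=2).run(suite)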
