Skip to content

Commit 56081a3

Browse files
committed
add autodoc to docs/
1 parent 7a0c9fb commit 56081a3

File tree

12 files changed

+358
-110
lines changed

12 files changed

+358
-110
lines changed

docs/source/conf.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,32 @@
99
# If extensions (or modules to document with autodoc) are in another directory,
1010
# add these directories to sys.path here. If the directory is relative to the
1111
# documentation root, use os.path.abspath to make it absolute, like shown here.
12-
#
13-
# import os
14-
# import sys
15-
# sys.path.insert(0, os.path.abspath('.'))
12+
from pathlib import Path
13+
import sys
14+
15+
module_path = Path(__file__).parents[2].resolve()
16+
sys.path.insert(0, module_path.as_posix())
1617

1718

1819
# -- Project information -----------------------------------------------------
1920

20-
project = 'nca-models'
21-
copyright = '2024, Henry Krumb'
22-
author = 'Henry Krumb'
21+
project = "NCALab"
22+
copyright = "2025, MECLab TU Darmstadt"
23+
author = "Henry Krumb"
2324

2425

2526
# -- General configuration ---------------------------------------------------
2627

2728
# Add any Sphinx extension module names here, as strings. They can be
2829
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
2930
# ones.
30-
extensions = [
31-
]
31+
extensions = ["sphinx.ext.doctest", "autoapi.extension"]
32+
33+
autoapi_type = "python"
34+
autoapi_dirs = [module_path / "ncalab"]
3235

3336
# Add any paths that contain templates here, relative to this directory.
34-
templates_path = ['_templates']
37+
templates_path = ["_templates"]
3538

3639
# List of patterns, relative to source directory, that match files and
3740
# directories to ignore when looking for source files.
@@ -44,9 +47,9 @@
4447
# The theme to use for HTML and HTML Help pages. See the documentation for
4548
# a list of builtin themes.
4649
#
47-
html_theme = 'alabaster'
50+
html_theme = "alabaster"
4851

4952
# Add any paths that contain custom static files (such as style sheets) here,
5053
# relative to this directory. They are copied after the builtin static files,
5154
# so a file named "default.css" will overwrite the builtin "default.css".
52-
html_static_path = ['_static']
55+
html_static_path = ["_static"]

ncalab/experiment.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from os import PathLike
2+
import json
3+
4+
from .paths import WEIGHTS_PATH
5+
6+
7+
class Experiment:
8+
def __init__(self, model_name: str):
9+
self.model_name = model_name
10+
self.model_path = WEIGHTS_PATH / (model_name + ".pth")
11+
12+
@staticmethod
13+
def load(path: PathLike):
14+
with open(path, "r") as f:
15+
d = json.load(f)
16+
model_name = d["model_name"]
17+
experiment = Experiment(model_name)
18+
return experiment

ncalab/models/basicNCA.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def __init__(
2929

3030

3131
class BasicNCAModel(nn.Module):
32+
"""
33+
Abstract base class for NCA models.
34+
"""
35+
3236
def __init__(
3337
self,
3438
device: torch.device,
@@ -47,20 +51,20 @@ def __init__(
4751
pad_noise: bool = False,
4852
autostepper: Optional[AutoStepper] = None,
4953
):
50-
"""Basic abstract class for NCA models.
51-
52-
Args:
53-
device (device): Pytorch device descriptor.
54-
num_image_channels (int): Number of channels reserved for input image.
55-
num_hidden_channels (int): Number of hidden channels (communication channels).
56-
num_output_channels (int): Number of output channels.
57-
fire_rate (float, optional): Fire rate for stochastic weight update. Defaults to 0.5.
58-
hidden_size (int, optional): Number of neurons in hidden layer. Defaults to 128.
59-
use_alive_mask (bool, optional): Whether to use alive masking during training. Defaults to False.
60-
immutable_image_channels (bool, optional): If image channels should be fixed during inference,
61-
which is the case for most segmentation or classification problems. Defaults to True.
62-
num_learned_filters (int, optional): Number of learned filters. If zero, use two sobel filters instead. Defaults to 2.
63-
dx_noise (float)
54+
"""
55+
Constructor.
56+
57+
:param device [device]: Pytorch device descriptor.
58+
:param num_image_channels [int]: Number of channels reserved for input image.
59+
:param num_hidden_channels [int]: Number of hidden channels (communication channels).
60+
:param num_output_channels [int]: Number of output channels.
61+
:param fire_rate [float]: Fire rate for stochastic weight update. Defaults to 0.5.
62+
:param hidden_size [int]: Number of neurons in hidden layer. Defaults to 128.
63+
:param use_alive_mask [bool]: Whether to use alive masking during training. Defaults to False.
64+
:param immutable_image_channels [bool]: If image channels should be fixed during inference,
65+
which is the case for most segmentation or classification problems. Defaults to True.
66+
:param num_learned_filters [int]: Number of learned filters. If zero, use two sobel filters instead. Defaults to 2.
67+
:param dx_noise [float]:
6468
"""
6569
super(BasicNCAModel, self).__init__()
6670

@@ -130,6 +134,14 @@ def __init__(
130134
self.meta: dict = {}
131135

132136
def prepare_input(self, x):
137+
"""
138+
Preprocess input. Intended to be overwritten by subclass, if preprocessing
139+
is necessary.
140+
141+
:param x [torch.Tensor]: Input tensor to preprocess.
142+
143+
:returns: Processed tensor.
144+
"""
133145
return x
134146

135147
def alive(self, x):

ncalab/models/cascadeNCA.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ def downscale(image, scale):
1919

2020

2121
class CascadeNCA(BasicNCAModel):
22+
"""
23+
Chain multiple instances of the same NCA model, operating at different
24+
image scales.
25+
"""
26+
2227
def __init__(self, backbone: BasicNCAModel, scales: List[int], steps: List[int]):
2328
"""
24-
Chain multiple instances of the same NCA model, operating at different
25-
image scales.
29+
Constructor.
2630
2731
Args:
2832
backbone (Type): _description_

ncalab/models/classificationNCA.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,18 @@ def __init__(
2525
pad_noise: bool = False,
2626
):
2727
"""
28-
Args:
29-
device (torch.device): Compute device.
30-
num_image_channels (int): _description_
31-
num_hidden_channels (int): _description_
32-
num_classes (int): _description_
33-
fire_rate (float, optional): _description_. Defaults to 0.5.
34-
hidden_size (int, optional): _description_. Defaults to 128.
35-
use_alive_mask (bool, optional): _description_. Defaults to False.
36-
immutable_image_channels (bool, optional): _description_. Defaults to True.
37-
learned_filters (int, optional): _description_. Defaults to 0.
38-
pixel_wise_loss (bool, optional): Whether a prediction per pixel is desired, like in self-classifying MNIST. Defaults to False.
39-
filter_padding (str, optional): _description_. Defaults to "reflect".
40-
pad_noise (bool, optional): _description_. Defaults to False.
28+
:param device [torch.device]: Compute device.
29+
:param num_image_channels [int]: _description_
30+
:param num_hidden_channels [int]: _description_
31+
:param num_classes [int]: _description_
32+
:param fire_rate [float]: _description_. Defaults to 0.5.
33+
:param hidden_size [int]: _description_. Defaults to 128.
34+
:param use_alive_mask [bool]: _description_. Defaults to False.
35+
:param immutable_image_channels [bool]: _description_. Defaults to True.
36+
:param learned_filters [int]: _description_. Defaults to 0.
37+
:param pixel_wise_loss [bool]: Whether a prediction per pixel is desired, like in self-classifying MNIST. Defaults to False.
38+
:param filter_padding [str]: _description_. Defaults to "reflect".
39+
:param pad_noise [bool]: _description_. Defaults to False.
4140
"""
4241
super(ClassificationNCAModel, self).__init__(
4342
device,
@@ -59,15 +58,14 @@ def __init__(
5958
def classify(
6059
self, image: torch.Tensor, steps: int = 100, reduce: bool = False
6160
) -> torch.Tensor:
62-
"""_summary_
61+
"""
62+
Predict classification for an input image.
6363
64-
Args:
65-
image (torch.Tensor): Input image.
66-
steps (int, optional): Inference steps. Defaults to 100.
67-
reduce (bool, optional): Return a single softmax probability. Defaults to False.
64+
:param image [torch.Tensor]: Input image.
65+
:param steps [int]: Inference steps. Defaults to 100.
66+
:param reduce [bool]: Return a single softmax probability. Defaults to False.
6867
69-
Returns:
70-
(torch.Tensor): Single class index or vector of logits.
68+
:returns [torch.Tensor]: Single class index or vector of logits.
7169
"""
7270
with torch.no_grad():
7371
x = image.clone()

ncalab/search/search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def search(
135135
)
136136
writer = SummaryWriter(comment=experiment_name)
137137
model = self.model_class(self.device, **model_args)
138+
# TODO: allow k-fold
138139
trainer = BasicNCATrainer(model, **trainer_args)
139140
summary = trainer.train(
140141
dataloader_train,

ncalab/training/kfold.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,22 @@
1111

1212

1313
class TrainValRecord:
14+
"""
15+
Helper class, storing a training / validation data split to generate
16+
respective DataLoader objects.
17+
"""
18+
1419
def __init__(
1520
self,
1621
train: List[str],
1722
val: List[str],
1823
):
24+
"""
25+
Constructor.
26+
27+
:param train (List[str]): List of training image file paths
28+
:param val (List[str]): List of validation image file paths
29+
"""
1930
self.train = train
2031
self.val = val
2132

@@ -26,6 +37,10 @@ def dataloaders(
2637
transform=None,
2738
batch_sizes=None,
2839
):
40+
"""
41+
Generate a pair of training and validation DataLoader objects, based on
42+
a given DataSet subtype.
43+
"""
2944
if batch_sizes is None:
3045
batch_sizes = {"train": 8, "val": 8}
3146
dataset_train = DatasetType(path, self.train, transform)
@@ -45,7 +60,14 @@ def dataloaders(
4560

4661

4762
class SplitDefinition:
63+
"""
64+
Stores a k-fold cross-validation split.
65+
"""
66+
4867
def __init__(self):
68+
"""
69+
Constructor.
70+
"""
4971
self.folds: List[None | TrainValRecord] = []
5072
self.dataloader_test = None
5173

@@ -55,18 +77,21 @@ def read(path: PosixPath):
5577
Reads json files with split definitions, similar to those created by nnUNet.
5678
5779
Format is like
58-
[
59-
{
60-
"train": [ "filename0", "filename1",... ]
61-
"val": [ "filename2", "filename3",... ]
62-
},
63-
{
64-
...
65-
}
66-
]
67-
68-
Args:
69-
path (PosixPath): Path to JSON file containing split definition.
80+
81+
.. highlight:: python
82+
.. code-block:: python
83+
84+
[
85+
{
86+
"train": [ "filename0", "filename1",... ]
87+
"val": [ "filename2", "filename3",... ]
88+
},
89+
{
90+
...
91+
}
92+
]
93+
94+
:param path [PosixPath]: Path to JSON file containing split definition.
7095
"""
7196
with open(path, "r") as f:
7297
d = json.load(f)
@@ -89,9 +114,15 @@ def __getitem__(self, idx) -> TrainValRecord:
89114

90115
class KFoldCrossValidationTrainer:
91116
def __init__(self, trainer: BasicNCATrainer, split: SplitDefinition):
117+
"""
118+
Constructor.
119+
120+
:param trainer [BasicNCATrainer]: BasicNCATrainer, to train each individual fold.
121+
:param split [SplitDefinition]: Definition of the split used for k-fold cross-training.
122+
"""
92123
self.trainer = trainer
93124
self.model_prototype = copy.deepcopy(trainer.nca)
94-
self.model_name = trainer.model_path.with_suffix('')
125+
self.model_name = trainer.model_path.with_suffix("")
95126
self.split = split
96127

97128
def train(
@@ -102,6 +133,18 @@ def train(
102133
batch_sizes: None | Dict = None,
103134
save_every: int | None = None,
104135
) -> List[TrainingSummary]:
136+
"""
137+
Run training loop with a single function call.
138+
139+
:param DatasetType [Type]: Type of dataset class to use.
140+
:param datapath [Path]: _description_
141+
:param transform: Data transform, e.g. initialized via Albumentations.
142+
:param batch_sizes: Dict of batch sizes per set, e.g. {"train": 8, "val": 16}. Defaults to None.
143+
:param save_every [int]: _description_. Defaults to None.
144+
:param plot_function: Plot function override. If None, use model's default. Defaults to None.
145+
146+
:returns [List[TrainingSummary]]: List of TrainingSummary objects, one per fold.
147+
"""
105148
k = len(self.split)
106149
summaries = []
107150
for i in range(k):

0 commit comments

Comments
 (0)