Skip to content

Commit 0de5e41

Browse files
committed
chebai - docstring + typehints
1 parent 847c831 commit 0de5e41

File tree

6 files changed

+441
-52
lines changed

6 files changed

+441
-52
lines changed

chebai/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
11
import os
2-
32
import torch
3+
from typing import Any
44

5+
# Get the absolute path of the current file's directory
56
MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
67

78

89
class CustomTensor(torch.Tensor):
    """
    Thin constructor wrapper around `torch.tensor`.

    Calling ``CustomTensor(data)`` delegates to ``torch.tensor(data)``. Because
    ``__new__`` hands back the result of ``torch.tensor`` unchanged, the object
    a caller receives is a plain ``torch.Tensor`` and not an instance of this
    subclass; the class defines no attributes of its own.
    """

    def __new__(cls, data: Any) -> torch.Tensor:
        """
        Builds a tensor from the provided data.

        Args:
            data (Any): The data to be converted into a tensor; it is passed
                unchanged to ``torch.tensor``.

        Returns:
            torch.Tensor: A tensor containing the provided data.
        """
        return torch.tensor(data)

chebai/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from chebai.cli import cli
22

33
if __name__ == "__main__":
    # Entry point for the CLI application: call the `cli` function from the
    # `chebai.cli` module when this file is executed as the main program.
    # (A string literal placed here would be a discarded expression, not a
    # docstring, so the description lives in a comment.)
    cli()

chebai/callbacks.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,74 @@
33

44
from lightning.pytorch.callbacks import BasePredictionWriter
55
import torch
6+
from typing import Any, Dict, List, Union
67

78

89
class ChebaiPredictionWriter(BasePredictionWriter):
9-
def __init__(self, output_dir, write_interval, target_file="predictions.json"):
10+
"""
11+
A custom prediction writer for saving batch and epoch predictions during prediction.
12+
13+
This class inherits from `BasePredictionWriter` and is designed to save predictions
14+
in a specified output directory at specified intervals.
15+
16+
Args:
17+
output_dir (str): The directory where predictions will be saved.
18+
write_interval (str): The interval at which predictions will be written.
19+
target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json").
20+
"""
21+
22+
def __init__(
    self, output_dir: str, write_interval: str, target_file: str = "predictions.json"
) -> None:
    """
    Set up the writer and remember where its output goes.

    Args:
        output_dir (str): Directory under which predictions are stored.
        write_interval (str): Write interval, handed to the base class unchanged.
        target_file (str): File name stored for the epoch-level output
            (default: "predictions.json").
    """
    # The base class receives only the write interval.
    super().__init__(write_interval)
    self.target_file = target_file
    self.output_dir = output_dir
1331

1432
def write_on_batch_end(
    self,
    trainer: Any,
    pl_module: Any,
    prediction: Union[torch.Tensor, List[torch.Tensor]],
    batch_indices: List[int],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """
    Store the prediction of a single batch as ``<output_dir>/<dataloader_idx>/<batch_idx>.pt``.

    Args:
        trainer (Any): The trainer instance (not used here).
        pl_module (Any): The LightningModule instance (not used here).
        prediction (Union[torch.Tensor, List[torch.Tensor]]): The object passed to `torch.save`.
        batch_indices (List[int]): The indices of the batch (not used here).
        batch (Any): The current batch (not used here).
        batch_idx (int): The index of the batch; becomes the file name.
        dataloader_idx (int): The index of the dataloader; becomes the sub-directory name.
    """
    # One sub-directory per dataloader, created on demand.
    dataloader_dir = os.path.join(self.output_dir, str(dataloader_idx))
    os.makedirs(dataloader_dir, exist_ok=True)
    torch.save(prediction, os.path.join(dataloader_dir, f"{batch_idx}.pt"))
2757

28-
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
58+
def write_on_epoch_end(
59+
self,
60+
trainer: Any,
61+
pl_module: Any,
62+
predictions: List[Dict[str, Any]],
63+
batch_indices: List[int],
64+
) -> None:
65+
"""
66+
Saves all predictions at the end of each epoch in a JSON file.
67+
68+
Args:
69+
trainer (Any): The trainer instance.
70+
pl_module (Any): The LightningModule instance.
71+
predictions (List[Dict[str, Any]]): The list of prediction outputs from the model.
72+
batch_indices (List[int]): The indices of the batches.
73+
"""
2974
pred_list = []
3075
for p in predictions:
3176
idents = p["data"]["idents"]
@@ -35,7 +80,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
3580
else:
3681
labels = [None for _ in idents]
3782
output = torch.sigmoid(p["output"]["logits"]).tolist()
38-
for i, l, p in zip(idents, labels, output):
39-
pred_list.append(dict(ident=i, labels=l, predictions=p))
83+
for i, l, o in zip(idents, labels, output):
84+
pred_list.append(dict(ident=i, labels=l, predictions=o))
4085
with open(os.path.join(self.output_dir, self.target_file), "wt") as fout:
4186
json.dump(pred_list, fout)

chebai/cli.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,38 @@
11
from typing import Dict, Set
2-
32
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
4-
53
from chebai.trainer.CustomTrainer import CustomTrainer
64

75

86
class ChebaiCLI(LightningCLI):
9-
def __init__(self, *args, **kwargs):
7+
"""
8+
Custom CLI subclass for Chebai project based on PyTorch Lightning's LightningCLI.
9+
10+
Args:
11+
save_config_kwargs (dict): Keyword arguments for saving configuration.
12+
parser_kwargs (dict): Keyword arguments for parser configuration.
13+
14+
Attributes:
15+
save_config_kwargs (dict): Configuration options for saving.
16+
parser_kwargs (dict): Configuration options for the argument parser.
17+
"""
18+
19+
def __init__(self, save_config_kwargs: dict, parser_kwargs: dict, **kwargs) -> None:
    """
    Initialize ChebaiCLI with custom trainer and configure parser settings.

    Args:
        save_config_kwargs (dict): Keyword arguments for saving configuration.
        parser_kwargs (dict): Keyword arguments for parser configuration.
        **kwargs: Any further keyword arguments, forwarded unchanged to the
            parent class initializer.
    """
    # Forward the declared parameters by keyword. The signature has no
    # `*args`, so referencing `args`/`kwargs` without declaring them would
    # raise NameError as soon as the CLI is constructed.
    super().__init__(
        trainer_class=CustomTrainer,
        save_config_kwargs=save_config_kwargs,
        parser_kwargs=parser_kwargs,
        **kwargs,
    )
1128

1229
def add_arguments_to_parser(self, parser: LightningArgumentParser):
30+
"""
31+
Add custom arguments to the argument parser.
32+
33+
Args:
34+
parser (LightningArgumentParser): Argument parser instance.
35+
"""
1336
for kind in ("train", "val", "test"):
1437
for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
1538
parser.link_arguments(
@@ -25,7 +48,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
2548

2649
@staticmethod
2750
def subcommands() -> Dict[str, Set[str]]:
28-
"""Defines the list of available subcommands and the arguments to skip."""
51+
"""
52+
Defines the list of available subcommands and the arguments to skip.
53+
54+
Returns:
55+
Dict[str, Set[str]]: Dictionary where keys are subcommands and values are sets of arguments to skip.
56+
"""
2957
return {
3058
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
3159
"validate": {"model", "dataloaders", "datamodule"},
@@ -36,6 +64,9 @@ def subcommands() -> Dict[str, Set[str]]:
3664

3765

3866
def cli():
67+
"""
68+
Main function to instantiate and run the ChebaiCLI.
69+
"""
3970
r = ChebaiCLI(
4071
save_config_kwargs={"config_filename": "lightning_config.yaml"},
4172
parser_kwargs={"parser_mode": "omegaconf"},

0 commit comments

Comments
 (0)