
Commit bdda7ee

Merge pull request #35 from ChEB-AI/code_documentation
Code documentation
2 parents 0176517 + cdf5989

37 files changed: +2925 additions, -631 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ cython_debug/
 #.idea/

 # configs/ # commented as new configs can be added as a part of a feature
+
 /.idea
 /data
 /logs

chebai/__init__.py

Lines changed: 21 additions & 1 deletion
@@ -1,10 +1,30 @@
 import os
+from typing import Any

 import torch

+# Get the absolute path of the current file's directory
 MODULE_PATH = os.path.abspath(os.path.dirname(__file__))


 class CustomTensor(torch.Tensor):
-    def __new__(cls, data):
+    """
+    A custom tensor class inheriting from `torch.Tensor`.
+
+    This class allows for the creation of tensors using the provided data.
+
+    Attributes:
+        data (Any): The data to be converted into a tensor.
+    """
+
+    def __new__(cls, data: Any) -> "CustomTensor":
+        """
+        Creates a new instance of CustomTensor.
+
+        Args:
+            data (Any): The data to be converted into a tensor.
+
+        Returns:
+            CustomTensor: A tensor containing the provided data.
+        """
        return torch.tensor(data)
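For orientation, a minimal sketch of how the documented class behaves (illustrative values, not part of this commit):

# Sketch: using CustomTensor (illustrative values, not part of this PR).
from chebai import CustomTensor

t = CustomTensor([1, 2, 3])
print(t)        # tensor([1, 2, 3])
# Note: __new__ returns the result of torch.tensor(data), so the object
# is a plain torch.Tensor rather than an instance of CustomTensor.
print(type(t))  # <class 'torch.Tensor'>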

chebai/__main__.py

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,10 @@
 from chebai.cli import cli

 if __name__ == "__main__":
+    """
+    Entry point for the CLI application.
+
+    This script calls the `cli` function from the `chebai.cli` module
+    when executed as the main program.
+    """
     cli()
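With this entry point, the package can be run as `python -m chebai`, which dispatches to `chebai.cli.cli`. Worth noting: a string literal placed inside the `if __name__ == "__main__":` block is evaluated as a bare expression and is not attached anywhere as a docstring, so it acts only as an inline comment.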

chebai/callbacks.py

Lines changed: 59 additions & 14 deletions
@@ -1,31 +1,76 @@
 import json
 import os
+from typing import Any, Dict, List, Literal, Union

 import torch
 from lightning.pytorch.callbacks import BasePredictionWriter


 class ChebaiPredictionWriter(BasePredictionWriter):
-    def __init__(self, output_dir, write_interval, target_file="predictions.json"):
+    """
+    A custom prediction writer for saving batch and epoch predictions during model training.
+
+    This class inherits from `BasePredictionWriter` and is designed to save predictions
+    in a specified output directory at specified intervals.
+
+    Args:
+        output_dir (str): The directory where predictions will be saved.
+        write_interval (str): The interval at which predictions will be written.
+        target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json").
+    """
+
+    def __init__(
+        self,
+        output_dir: str,
+        write_interval: Literal["batch", "epoch", "batch_and_epoch"],
+        target_file: str = "predictions.json",
+    ) -> None:
         super().__init__(write_interval)
         self.output_dir = output_dir
         self.target_file = target_file

     def write_on_batch_end(
         self,
-        trainer,
-        pl_module,
-        prediction,
-        batch_indices,
-        batch,
-        batch_idx,
-        dataloader_idx,
-    ):
-        outpath = os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt")
-        os.makedirs(outpath, exist_ok=True)
+        trainer: Any,
+        pl_module: Any,
+        prediction: Union[torch.Tensor, List[torch.Tensor]],
+        batch_indices: List[int],
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """
+        Saves batch predictions at the end of each batch.
+
+        Args:
+            trainer (Any): The trainer instance.
+            pl_module (Any): The LightningModule instance.
+            prediction (Union[torch.Tensor, List[torch.Tensor]]): The prediction output from the model.
+            batch_indices (List[int]): The indices of the batch.
+            batch (Any): The current batch.
+            batch_idx (int): The index of the batch.
+            dataloader_idx (int): The index of the dataloader.
+        """
+        outpath = os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt")
+        os.makedirs(os.path.dirname(outpath), exist_ok=True)
         torch.save(prediction, outpath)

-    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+    def write_on_epoch_end(
+        self,
+        trainer: Any,
+        pl_module: Any,
+        predictions: List[Dict[str, Any]],
+        batch_indices: List[int],
+    ) -> None:
+        """
+        Saves all predictions at the end of each epoch in a JSON file.
+
+        Args:
+            trainer (Any): The trainer instance.
+            pl_module (Any): The LightningModule instance.
+            predictions (List[Dict[str, Any]]): The list of prediction outputs from the model.
+            batch_indices (List[int]): The indices of the batches.
+        """
         pred_list = []
         for p in predictions:
             idents = p["data"]["idents"]

@@ -35,7 +80,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
             else:
                 labels = [None for _ in idents]
             output = torch.sigmoid(p["output"]["logits"]).tolist()
-            for i, l, p in zip(idents, labels, output):
-                pred_list.append(dict(ident=i, labels=l, predictions=p))
+            for i, l, o in zip(idents, labels, output):
+                pred_list.append(dict(ident=i, labels=l, predictions=o))
         with open(os.path.join(self.output_dir, self.target_file), "wt") as fout:
             json.dump(pred_list, fout)
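Beyond the docstrings, this diff fixes two latent bugs: `write_on_batch_end` previously passed the integer `dataloader_idx` straight to `os.path.join` and created the output file path itself as a directory, and the epoch-end loop shadowed the outer loop variable `p`. A minimal sketch of how the writer would be attached to a Lightning `Trainer` (the import path follows the file name in this diff; the model and datamodule are hypothetical placeholders):

# Sketch: attaching ChebaiPredictionWriter to a Lightning Trainer.
# `my_model` and `my_datamodule` are hypothetical placeholders assumed
# to be a LightningModule and a LightningDataModule.
from lightning.pytorch import Trainer

from chebai.callbacks import ChebaiPredictionWriter

writer = ChebaiPredictionWriter(
    output_dir="predictions",  # epoch-end JSON goes to predictions/predictions.json
    write_interval="epoch",    # one of "batch", "epoch", "batch_and_epoch"
)
trainer = Trainer(callbacks=[writer])
# return_predictions=False avoids holding all outputs in memory,
# since the callback already persists them to disk.
trainer.predict(my_model, datamodule=my_datamodule, return_predictions=False)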

chebai/callbacks/epoch_metrics.py

Lines changed: 71 additions & 14 deletions
@@ -2,13 +2,38 @@
 import torchmetrics


-def custom_reduce_fx(input):
+def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor:
+    """
+    Custom reduction function for distributed training.
+
+    Args:
+        input (torch.Tensor): The input tensor to be reduced.
+
+    Returns:
+        torch.Tensor: The reduced tensor.
+    """
     print(f"called reduce (device: {input.device})")
     return torch.sum(input, dim=0)


 class MacroF1(torchmetrics.Metric):
-    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+    """
+    Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class.
+    This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined
+    values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value.
+    Here, the mean is only taken over classes which have at least one positive sample.
+
+    Args:
+        num_labels (int): Number of classes/labels.
+        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
+            before returning the value at the step. Default: False.
+        threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
+            Default: 0.5.
+    """
+
+    def __init__(
+        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
+    ):
         super().__init__(dist_sync_on_step=dist_sync_on_step)

         self.add_state(

@@ -28,15 +53,29 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
         )
         self.threshold = threshold

-    def update(self, preds: torch.Tensor, labels: torch.Tensor):
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """
+        Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels.
+
+        Args:
+            preds (torch.Tensor): Predictions from the model.
+            labels (torch.Tensor): Ground truth labels.
+        """
         tps = torch.sum(
             torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
         )
         self.true_positives += tps
         self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
         self.positive_labels += torch.sum(labels, dim=0)

-    def compute(self):
+    def compute(self) -> torch.Tensor:
+        """
+        Compute the Macro F1 score.
+
+        Returns:
+            torch.Tensor: The computed Macro F1 score.
+        """
+
         # ignore classes without positive labels
         # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
         # which is propagated to the classwise_f1 and then turned into 0

@@ -50,14 +89,22 @@ def compute(self):


 class BalancedAccuracy(torchmetrics.Metric):
-    """Balanced Accuracy = (TPR + TNR) / 2 = ( TP/(TP + FN) + (TN)/(TN + FP) ) / 2
-
-    This metric computes the balanced accuracy, which is the average of true positive rate (TPR)
-    and true negative rate (TNR). It is useful for imbalanced datasets where the classes are not
-    represented equally.
+    """
+    Computes the Balanced Accuracy, which is the average of true positive rate (TPR) and true negative rate (TNR).
+    Useful for imbalanced datasets.
+    Balanced Accuracy = (TPR + TNR)/2 = (TP/(TP + FN) + (TN)/(TN + FP))/2
+
+    Args:
+        num_labels (int): Number of classes/labels.
+        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
+            before returning the value at the step. Default: False.
+        threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
+            Default: 0.5.
     """

-    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+    def __init__(
+        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
+    ):
         super().__init__(dist_sync_on_step=dist_sync_on_step)

         self.add_state(

@@ -86,8 +133,14 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):

         self.threshold = threshold

-    def update(self, preds: torch.Tensor, labels: torch.Tensor):
-        """Update the TPs, TNs ,FPs and FNs"""
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """
+        Update the state (TPs, TNs, FPs, FNs) with the current batch of predictions and labels.
+
+        Args:
+            preds (torch.Tensor): Predictions from the model.
+            labels (torch.Tensor): Ground truth labels.
+        """

         # Size: Batch_size x Num_of_Classes;
         # summing over 1st dimension (dim=0), gives us the True positives per class

@@ -110,9 +163,13 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor):
         self.true_negatives += tns
         self.false_negatives += fns

-    def compute(self):
-        """Compute the average value of Balanced accuracy from each batch"""
+    def compute(self) -> torch.Tensor:
+        """
+        Compute the Balanced Accuracy.

+        Returns:
+            torch.Tensor: The computed Balanced Accuracy.
+        """
         tpr = self.true_positives / (self.true_positives + self.false_negatives)
         tnr = self.true_negatives / (self.true_negatives + self.false_positives)
         # Convert the nan values to 0
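A small sketch of the documented semantics on an illustrative batch: class 2 has no positive labels, so MacroF1 excludes it from the mean rather than assigning it a default value (values below are made up):

# Sketch: exercising the custom metrics on one illustrative batch.
import torch

from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1

preds = torch.tensor([[0.9, 0.2, 0.4],
                      [0.8, 0.6, 0.1]])  # predicted probabilities
labels = torch.tensor([[1, 0, 0],
                       [1, 1, 0]])       # class 2 has no positive samples

f1 = MacroF1(num_labels=3)
f1.update(preds, labels)
print(f1.compute())    # mean F1 over classes 0 and 1 only; class 2 is ignored

bacc = BalancedAccuracy(num_labels=3)
bacc.update(preds, labels)
print(bacc.compute())  # nan rates (e.g. TPR of class 2) are converted to 0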

chebai/callbacks/model_checkpoint.py

Lines changed: 41 additions & 9 deletions
@@ -10,13 +10,25 @@


 class CustomModelCheckpoint(ModelCheckpoint):
-    """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
-    same directory as the other logs"""
+    """
+    Custom checkpoint class that resolves checkpoint paths to ensure checkpoints are saved in the same directory
+    as other logs when using CustomLogger.
+    Inherits from PyTorch Lightning's ModelCheckpoint class.
+    """

-    def setup(
-        self, trainer: "Trainer", pl_module: "LightningModule", stage: str
-    ) -> None:
-        """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        """
+        Setup the directory path for saving checkpoints. If the directory path is not set, it resolves the checkpoint
+        directory using the custom logger's directory.
+
+        Note:
+            Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir
+
+        Args:
+            trainer (Trainer): The Trainer instance.
+            pl_module (LightningModule): The LightningModule instance.
+            stage (str): The stage of training (e.g., 'fit').
+        """
         if self.dirpath is not None:
             self.dirpath = None
         dirpath = self.__resolve_ckpt_dir(trainer)

@@ -26,16 +38,36 @@ def setup(
             self.__warn_if_dir_not_empty(self.dirpath)

     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
-        """Same as in parent class, duplicated because method in parent class is not accessible"""
+        """
+        Warn if the checkpoint directory is not empty.
+
+        Note:
+            Same as in parent class, duplicated because method in parent class is not accessible
+
+        Args:
+            dirpath (_PATH): The path to the checkpoint directory.
+        """
         if (
             self.save_top_k != 0
             and _is_dir(self._fs, dirpath, strict=True)
             and len(self._fs.ls(dirpath)) > 0
         ):
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

-    def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
-        """Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs"""
+    def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
+        """
+        Resolve the checkpoint directory path, ensuring compatibility with WandbLogger by saving checkpoints
+        in the same directory as Wandb logs.
+
+        Note:
+            Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs
+
+        Args:
+            trainer (Trainer): The Trainer instance.
+
+        Returns:
+            _PATH: The resolved checkpoint directory path.
+        """
         rank_zero_info(f"Resolving checkpoint dir (custom)")
         if self.dirpath is not None:
             # short circuit if dirpath was passed to ModelCheckpoint
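A minimal configuration sketch (the monitored metric name is a hypothetical placeholder); `dirpath` is deliberately left unset so that `setup` resolves it from the logger's directory:

# Sketch: using CustomModelCheckpoint so checkpoints land next to the wandb logs.
from lightning.pytorch import Trainer

from chebai.callbacks.model_checkpoint import CustomModelCheckpoint

checkpoint = CustomModelCheckpoint(
    monitor="val_loss",  # hypothetical metric name logged by the model
    save_top_k=3,
    mode="min",
    # dirpath is left unset on purpose: setup() resolves it via the logger
    # so checkpoints are saved alongside the other logs.
)
trainer = Trainer(callbacks=[checkpoint])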
