Skip to content

Commit 96e9c42

Browse files
author
sfluegel
committed
additions to documentation
1 parent 2ec99a3 commit 96e9c42

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor:
1919
class MacroF1(torchmetrics.Metric):
2020
"""
2121
Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class.
22+
This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined
23+
values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value.
24+
Here, the mean is only taken over classes which have at least one positive sample.
2225
2326
Args:
2427
num_labels (int): Number of classes/labels.

chebai/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def __init__(self, save_config_kwargs: dict, parser_kwargs: dict):
2828

2929
def add_arguments_to_parser(self, parser: LightningArgumentParser):
3030
"""
31-
Add custom arguments to the argument parser.
31+
Link input parameters that are used by different classes (e.g. number of labels)
32+
see https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#argument-linking
3233
3334
Args:
3435
parser (LightningArgumentParser): Argument parser instance.

chebai/loss/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class DisjointLoss(ImplicationLoss):
160160
Disjoint Loss module, extending ImplicationLoss.
161161
162162
Args:
163-
path_to_disjointness (str): Path to the disjointness data file.
163+
path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes)
164164
data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels.
165165
base_loss (torch.nn.Module, optional): Base loss function. Defaults to None.
166166
disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100.

chebai/preprocessing/datasets/base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
229229
dataset = dataset[: self.data_limit]
230230
return DataLoader(
231231
dataset,
232-
collate_fn=self.reader.collater,
232+
collate_fn=self.reader.collator,
233233
batch_size=self.batch_size,
234234
**kwargs,
235235
)
@@ -312,8 +312,8 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
312312
Returns the validation DataLoader.
313313
314314
Args:
315-
*args: Additional positional arguments.
316-
**kwargs: Additional keyword arguments.
315+
*args: Additional positional arguments (unused).
316+
**kwargs: Additional keyword arguments, passed to dataloader().
317317
318318
Returns:
319319
Union[DataLoader, List[DataLoader]]: A DataLoader object for validation data.
@@ -331,8 +331,8 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
331331
Returns the test DataLoader.
332332
333333
Args:
334-
*args: Additional positional arguments.
335-
**kwargs: Additional keyword arguments.
334+
*args: Additional positional arguments (unused).
335+
**kwargs: Additional keyword arguments, passed to dataloader().
336336
337337
Returns:
338338
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
@@ -346,8 +346,8 @@ def predict_dataloader(
346346
Returns the predict DataLoader.
347347
348348
Args:
349-
*args: Additional positional arguments.
350-
**kwargs: Additional keyword arguments.
349+
*args: Additional positional arguments (unused).
350+
**kwargs: Additional keyword arguments, passed to dataloader().
351351
352352
Returns:
353353
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
@@ -520,7 +520,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
520520
]
521521
return DataLoader(
522522
dataset,
523-
collate_fn=self.reader.collater,
523+
collate_fn=self.reader.collator,
524524
batch_size=self.batch_size,
525525
**kwargs,
526526
)

0 commit comments

Comments
 (0)