Skip to content

Commit 847c831

Browse files
committed
trainer : docstring + typehints
1 parent 5a58c1f commit 847c831

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, Any, Tuple
22
import logging
33

44
from lightning import LightningModule, Trainer
@@ -16,6 +16,13 @@
1616

1717
class CustomTrainer(Trainer):
1818
def __init__(self, *args, **kwargs):
19+
"""
20+
Initializes the CustomTrainer class, logging additional hyperparameters to the custom logger if specified.
21+
22+
Args:
23+
*args: Positional arguments for the Trainer class.
24+
**kwargs: Keyword arguments for the Trainer class.
25+
"""
1926
self.init_args = args
2027
self.init_kwargs = kwargs
2128
super().__init__(*args, **kwargs)
@@ -32,7 +39,17 @@ def __init__(self, *args, **kwargs):
3239
log_kwargs[log_key] = log_value
3340
self.logger.log_hyperparams(log_kwargs)
3441

35-
def _resolve_logging_argument(self, key, value):
42+
def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
43+
"""
44+
Resolves logging arguments, handling nested structures such as lists and complex objects.
45+
46+
Args:
47+
key: The key of the argument.
48+
value: The value of the argument.
49+
50+
Returns:
51+
A tuple containing the resolved key and value.
52+
"""
3653
if isinstance(value, list):
3754
key_value_pairs = [
3855
self._resolve_logging_argument(f"{key}_{i}", v)
@@ -58,7 +75,17 @@ def predict_from_file(
5875
input_path: _PATH,
5976
save_to: _PATH = "predictions.csv",
6077
classes_path: Optional[_PATH] = None,
61-
):
78+
) -> None:
79+
"""
80+
Loads a model from a checkpoint and makes predictions on input data from a file.
81+
82+
Args:
83+
model: The model to use for predictions.
84+
checkpoint_path: Path to the model checkpoint.
85+
input_path: Path to the input file containing SMILES strings.
86+
save_to: Path to save the predictions CSV file.
87+
classes_path: Optional path to a file containing class names.
88+
"""
6289
loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
6390
with open(input_path, "r") as input:
6491
smiles_strings = [inp.strip() for inp in input.readlines()]
@@ -71,7 +98,19 @@ def predict_from_file(
7198
predictions_df.index = smiles_strings
7299
predictions_df.to_csv(save_to)
73100

74-
def _predict_smiles(self, model: LightningModule, smiles: List[str]):
101+
def _predict_smiles(
102+
self, model: LightningModule, smiles: List[str]
103+
) -> torch.Tensor:
104+
"""
105+
Predicts the output for a list of SMILES strings using the model.
106+
107+
Args:
108+
model: The model to use for predictions.
109+
smiles: A list of SMILES strings.
110+
111+
Returns:
112+
A tensor containing the predictions.
113+
"""
75114
reader = ChemDataReader()
76115
parsed_smiles = [reader._read_data(s) for s in smiles]
77116
x = pad_sequence(
@@ -91,6 +130,12 @@ def _predict_smiles(self, model: LightningModule, smiles: List[str]):
91130

92131
@property
93132
def log_dir(self) -> Optional[str]:
133+
"""
134+
Returns the logging directory.
135+
136+
Returns:
137+
The path to the logging directory if available, else the default root directory.
138+
"""
94139
if len(self.loggers) > 0:
95140
logger = self.loggers[0]
96141
if isinstance(logger, WandbLogger):

0 commit comments

Comments
 (0)