1- from typing import List , Optional
1+ from typing import List , Optional , Any , Tuple
22import logging
33
44from lightning import LightningModule , Trainer
1616
1717class CustomTrainer (Trainer ):
1818 def __init__ (self , * args , ** kwargs ):
19+ """
20+ Initializes the CustomTrainer class, logging additional hyperparameters to the custom logger if specified.
21+
22+ Args:
23+ *args: Positional arguments for the Trainer class.
24+ **kwargs: Keyword arguments for the Trainer class.
25+ """
1926 self .init_args = args
2027 self .init_kwargs = kwargs
2128 super ().__init__ (* args , ** kwargs )
@@ -32,7 +39,17 @@ def __init__(self, *args, **kwargs):
3239 log_kwargs [log_key ] = log_value
3340 self .logger .log_hyperparams (log_kwargs )
3441
35- def _resolve_logging_argument (self , key , value ):
42+ def _resolve_logging_argument (self , key : str , value : Any ) -> Tuple [str , Any ]:
43+ """
44+ Resolves logging arguments, handling nested structures such as lists and complex objects.
45+
46+ Args:
47+ key: The key of the argument.
48+ value: The value of the argument.
49+
50+ Returns:
51+ A tuple containing the resolved key and value.
52+ """
3653 if isinstance (value , list ):
3754 key_value_pairs = [
3855 self ._resolve_logging_argument (f"{ key } _{ i } " , v )
@@ -58,7 +75,17 @@ def predict_from_file(
5875 input_path : _PATH ,
5976 save_to : _PATH = "predictions.csv" ,
6077 classes_path : Optional [_PATH ] = None ,
61- ):
78+ ) -> None :
79+ """
80+ Loads a model from a checkpoint and makes predictions on input data from a file.
81+
82+ Args:
83+ model: The model to use for predictions.
84+ checkpoint_path: Path to the model checkpoint.
85+ input_path: Path to the input file containing SMILES strings.
86+ save_to: Path to save the predictions CSV file.
87+ classes_path: Optional path to a file containing class names.
88+ """
6289 loaded_model = model .__class__ .load_from_checkpoint (checkpoint_path )
6390 with open (input_path , "r" ) as input :
6491 smiles_strings = [inp .strip () for inp in input .readlines ()]
@@ -71,7 +98,19 @@ def predict_from_file(
7198 predictions_df .index = smiles_strings
7299 predictions_df .to_csv (save_to )
73100
74- def _predict_smiles (self , model : LightningModule , smiles : List [str ]):
101+ def _predict_smiles (
102+ self , model : LightningModule , smiles : List [str ]
103+ ) -> torch .Tensor :
104+ """
105+ Predicts the output for a list of SMILES strings using the model.
106+
107+ Args:
108+ model: The model to use for predictions.
109+ smiles: A list of SMILES strings.
110+
111+ Returns:
112+ A tensor containing the predictions.
113+ """
75114 reader = ChemDataReader ()
76115 parsed_smiles = [reader ._read_data (s ) for s in smiles ]
77116 x = pad_sequence (
@@ -91,6 +130,12 @@ def _predict_smiles(self, model: LightningModule, smiles: List[str]):
91130
92131 @property
93132 def log_dir (self ) -> Optional [str ]:
133+ """
134+ Returns the logging directory.
135+
136+ Returns:
137+ The path to the logging directory if available, else the default root directory.
138+ """
94139 if len (self .loggers ) > 0 :
95140 logger = self .loggers [0 ]
96141 if isinstance (logger , WandbLogger ):
0 commit comments