Skip to content

Commit 40491e5

Browse files
committed
fix label None error
1 parent 517a5a2 commit 40491e5

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

chebai/models/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@ def predict_step(
235235
assert isinstance(batch, XYData)
236236
batch = batch.to(self.device)
237237
data = self._process_batch(batch, batch_idx)
238-
labels = data["labels"]
239238
model_output = self(data, **data.get("model_kwargs", dict()))
239+
240+
# Dummy labels to avoid errors in _get_prediction_and_labels
241+
labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
240242
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
241243
return pr
242244

chebai/result/prediction.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,33 @@ def __init__(self):
2020
def predict_from_file(
2121
self,
2222
checkpoint_path: _PATH,
23-
input_path: _PATH,
23+
smiles_file_path: _PATH,
2424
save_to: _PATH = "predictions.csv",
2525
classes_path: Optional[_PATH] = None,
26+
batch_size: Optional[int] = None,
2627
) -> None:
2728
"""
2829
Loads a model from a checkpoint and makes predictions on input data from a file.
2930
3031
Args:
31-
model: The model to use for predictions.
3232
checkpoint_path: Path to the model checkpoint.
33-
input_path: Path to the input file containing SMILES strings.
33+
smiles_file_path: Path to the input file containing SMILES strings.
3434
save_to: Path to save the predictions CSV file.
35-
classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
35+
classes_path: Optional path to a file containing class names:
36+
if no class names are provided, code will try to get the class path
37+
from the datamodule, else the columns will be numbered.
38+
batch_size: Optional batch size for the DataLoader. If not provided,
39+
the default from the datamodule will be used.
3640
"""
37-
with open(input_path, "r") as input:
41+
with open(smiles_file_path, "r") as input:
3842
smiles_strings = [inp.strip() for inp in input.readlines()]
3943

4044
self.predict_smiles(
4145
checkpoint_path,
4246
smiles=smiles_strings,
4347
classes_path=classes_path,
4448
save_to=save_to,
49+
batch_size=batch_size,
4550
)
4651

4752
@torch.inference_mode()
@@ -51,16 +56,24 @@ def predict_smiles(
5156
smiles: List[str],
5257
classes_path: Optional[_PATH] = None,
5358
save_to: Optional[_PATH] = None,
59+
batch_size: Optional[int] = None,
5460
**kwargs,
5561
) -> torch.Tensor | None:
5662
"""
5763
Predicts the output for a list of SMILES strings using the model.
5864
5965
Args:
60-
model: The model to use for predictions.
66+
checkpoint_path: Path to the model checkpoint.
6167
smiles: A list of SMILES strings.
62-
63-
Returns:
68+
classes_path: Optional path to a file containing class names. If no class
69+
names are provided, code will try to get the class path from the datamodule,
70+
else the columns will be numbered.
71+
save_to: Optional path to save the predictions CSV file. If not provided,
72+
predictions will be returned as a tensor.
73+
batch_size: Optional batch size for the DataLoader. If not provided, the default
74+
from the datamodule will be used.
75+
76+
Returns: (if save_to is None)
6477
A tensor containing the predictions.
6578
"""
6679
ckpt_file = torch.load(
@@ -71,10 +84,13 @@ def predict_smiles(
7184
dm_hparams.pop("splits_file_path")
7285
dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams)
7386
print(f"Loaded datamodule class: {dm.__class__.__name__}")
87+
if batch_size is not None:
88+
dm.batch_size = batch_size
7489

7590
model_hparams = ckpt_file["hyper_parameters"]
7691
model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams)
7792
model.eval()
93+
# TODO: Enable torch.compile when supported
7894
# model = torch.compile(model)
7995
print(f"Loaded model class: {model.__class__.__name__}")
8096

0 commit comments

Comments
 (0)