Skip to content

Commit 82b365c

Browse files
committed
Add predict pipeline to the data module (dm) and Lightning module (lm)
1 parent cfbf392 commit 82b365c

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

chebai/models/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,13 @@ def predict_step(
232232
Returns:
233233
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
234234
"""
235-
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
235+
assert isinstance(batch, XYData)
236+
batch = batch.to(self.device)
237+
data = self._process_batch(batch, batch_idx)
238+
labels = data["labels"]
239+
model_output = self(data, **data.get("model_kwargs", dict()))
240+
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
241+
return pr
236242

237243
def _execute(
238244
self,

chebai/preprocessing/datasets/base.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
339339
for d in tqdm.tqdm(self._load_dict(path), total=lines)
340340
if d["features"] is not None
341341
]
342+
343+
return self._filter_to_token_limit(data)
344+
345+
def _filter_to_token_limit(
346+
self, data: List[Dict[str, Any]]
347+
) -> List[Dict[str, Any]]:
342348
# filter for missing features in resulting data, keep features length below token limit
343-
data = [
349+
return [
344350
val
345351
for val in data
346352
if val["features"] is not None
@@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
349355
)
350356
]
351357

352-
return data
353-
354358
def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
355359
"""
356360
Returns the train DataLoader.
@@ -400,10 +404,13 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
400404
Returns:
401405
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
402406
"""
407+
403408
return self.dataloader("test", shuffle=False, **kwargs)
404409

405410
def predict_dataloader(
406-
self, *args, **kwargs
411+
self,
412+
smiles_list: List[str],
413+
**kwargs,
407414
) -> Union[DataLoader, List[DataLoader]]:
408415
"""
409416
Returns the predict DataLoader.
@@ -415,7 +422,21 @@ def predict_dataloader(
415422
Returns:
416423
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
417424
"""
418-
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
425+
426+
data = [
427+
self.reader.to_data(
428+
{"id": f"smiles_{idx}", "features": smiles, "labels": None}
429+
)
430+
for idx, smiles in enumerate(smiles_list)
431+
]
432+
data = self._filter_to_token_limit(data)
433+
434+
return DataLoader(
435+
data,
436+
collate_fn=self.reader.collator,
437+
batch_size=self.batch_size,
438+
**kwargs,
439+
)
419440

420441
def prepare_data(self, *args, **kwargs) -> None:
421442
if self._prepare_data_flag != 1:

chebai/trainer/CustomTrainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pandas as pd
66
import torch
7-
from lightning import LightningModule, Trainer
7+
from lightning import Trainer
88
from lightning.fabric.utilities.data import _set_sampler_epoch
99
from lightning.fabric.utilities.types import _PATH
1010
from lightning.pytorch.cli import instantiate_module
@@ -87,6 +87,7 @@ def predict_from_file(
8787
input_path: _PATH,
8888
save_to: _PATH = "predictions.csv",
8989
classes_path: Optional[_PATH] = None,
90+
**kwargs,
9091
) -> None:
9192
"""
9293
Loads a model from a checkpoint and makes predictions on input data from a file.
@@ -114,6 +115,7 @@ def _predict_smiles(
114115
smiles: List[str],
115116
classes_path: Optional[_PATH] = None,
116117
save_to: _PATH = "predictions.csv",
118+
**kwargs,
117119
) -> torch.Tensor:
118120
"""
119121
Predicts the output for a list of SMILES strings using the model.

0 commit comments

Comments
 (0)