Skip to content

Commit beaf74e

Browse files
committed
update trainer for protein reader
1 parent 64d7623 commit beaf74e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.nn.utils.rnn import pad_sequence
1010

1111
from chebai.loggers.custom import CustomLogger
12-
from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
12+
from chebai.preprocessing.reader import CLS_TOKEN, ProteinDataReader
1313

1414
log = logging.getLogger(__name__)
1515

@@ -99,22 +99,22 @@ def predict_from_file(
9999
predictions_df.to_csv(save_to)
100100

101101
def _predict_smiles(
102-
self, model: LightningModule, smiles: List[str]
102+
self, model: LightningModule, sequence: List[str]
103103
) -> torch.Tensor:
104104
"""
105105
Predicts the output for a list of SMILES strings using the model.
106106
107107
Args:
108108
model: The model to use for predictions.
109-
smiles: A list of SMILES strings.
109+
sequence: Protein sequence.
110110
111111
Returns:
112112
A tensor containing the predictions.
113113
"""
114-
reader = ChemDataReader()
115-
parsed_smiles = [reader._read_data(s) for s in smiles]
114+
reader = ProteinDataReader()
115+
parsed_sequence = [reader._read_data(s) for s in sequence]
116116
x = pad_sequence(
117-
[torch.tensor(a, device=model.device) for a in parsed_smiles],
117+
[torch.tensor(a, device=model.device) for a in parsed_sequence],
118118
batch_first=True,
119119
)
120120
cls_tokens = (

0 commit comments

Comments
 (0)