Skip to content

Commit d906ad4

Browse files
committed
avoid non_null_labels key in loss kwargs
1 parent 63670dd commit d906ad4

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,12 @@ def _process_input_for_prediction(
444444
Returns:
445445
List[Dict[str, Any]]: Processed input data.
446446
"""
447+
# Add dummy labels because the collate function requires them.
448+
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
449+
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
447450
data = [
448451
self.reader.to_data(
449-
{"id": f"smiles_{idx}", "features": smiles, "labels": None}
452+
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
450453
)
451454
for idx, smiles in enumerate(smiles_list)
452455
]

chebai/result/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def predict_from_file(
6565

6666
preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)
6767

68-
predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())
68+
predictions_df = pd.DataFrame(preds.detach().cpu().numpy())
6969

7070
def _add_class_columns(class_file_path: _PATH):
7171
with open(class_file_path, "r") as f:

0 commit comments

Comments
 (0)