Skip to content

Commit 36e6162

Browse files
committed
ffn: fix error for loss kwargs
1 parent f13e935 commit 36e6162

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

chebai/models/ffn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _get_prediction_and_labels(self, data, labels, model_output):
3737
loss_kwargs = data.get("loss_kwargs", dict())
3838
if "non_null_labels" in loss_kwargs:
3939
n = loss_kwargs["non_null_labels"]
40-
d = data[n]
40+
d = d[n]
4141
return torch.sigmoid(d), labels.int() if labels is not None else None
4242

4343
def _process_for_loss(

0 commit comments

Comments
 (0)