Skip to content

Commit 19bdc2f

Browse files
authored
cast labels to int when labels is 2d (#243)
1 parent fd4440d commit 19bdc2f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/cnlpt/data/predictions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def __init__(
7777
# and all the tasks must be classification. This reflects how we structure
7878
# the data during preprocessing.
7979
assert all(t.type == CLASSIFICATION for t in tasks)
80-
task_labels = {t.name: self.raw.label_ids[:, t.index] for t in tasks}
80+
task_labels = {
81+
t.name: self.raw.label_ids[:, t.index].astype(int) for t in tasks
82+
}
8183
else:
8284
assert self.raw.label_ids.ndim == 3
8385
# If our labels are 3 dimensional, then label_ids has shape (batch, max_seq, L)

0 commit comments

Comments
 (0)