Skip to content

Commit c7437e6

Browse files
committed
Fix loss calls by casting target to float due to recent changes
1 parent 076e4ba commit c7437e6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

biapy/engine/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def __call__(self, y_pred, y_true):
379379
res_metrics[self.metric_names[pred_ch_start]].append(self.metric_func[pred_ch_start](_y_pred_class, _y_true[:, 1]))
380380
else:
381381
y_pred_slice = pd[:, pred_ch_start:pred_ch_end]
382-
y_true_slice = _y_true[:, gt_ch_start:gt_ch_end]
382+
y_true_slice = _y_true[:, gt_ch_start:gt_ch_end].float()
383383
if y_pred_slice.shape[1] != y_true_slice.shape[1] and "Db" == channel and db_val_type == "discretize":
384384
y_pred_slice = torch.argmax(y_pred_slice, dim=1).unsqueeze(1).float()
385385
y_true_slice = y_true_slice.float()
@@ -1299,7 +1299,7 @@ def __call__(self, y_pred, y_true):
12991299
gt_ch_end = pred_ch_end
13001300

13011301
y_pred_slice = pd[:, pred_ch_start:pred_ch_end]
1302-
y_true_slice = y_true[:, gt_ch_start:gt_ch_end]
1302+
y_true_slice = y_true[:, gt_ch_start:gt_ch_end].float()
13031303

13041304
# element-wise mask you wanted to use (float on same device)
13051305
mask_vals = self.channel_extra_opts.get(channel, {}).get("mask_values", False)

0 commit comments

Comments
 (0)