Skip to content

Commit fcdf767

Browse files
authored
Fix _segmentationloss for 3D images (#261)
1 parent 7fc1f58 commit fcdf767

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

FastVision/src/encodings/onehot.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,14 @@ end
3333
# `logitcrossentropy(...; dims = 3)` doesn't work on GPU:
3434

3535
function _segmentationloss(ypreds, ys; kwargs...)
36-
sz = size(ypreds)
37-
ypreds = reshape(ypreds, :, sz[end - 1], sz[end])
38-
ys = reshape(ys, :, size(ys, 3), size(ys, 4))
36+
sz_preds = size(ypreds)
37+
ypreds = reshape(ypreds, :, sz_preds[end - 1], sz_preds[end])
38+
sz = size(ys)
39+
ys = reshape(ys, :, sz[end - 1], sz[end])
3940
Flux.Losses.logitcrossentropy(ypreds, ys; dims = 2, kwargs...)
4041
end
42+
43+
@testset "segmentationloss" begin
44+
@test _segmentationloss(zeros(10, 10, 3, 5), zeros(10, 10, 3, 5)) == 0
45+
@test _segmentationloss(zeros(10, 10, 10, 3, 5), zeros(10, 10, 10, 3, 5)) == 0
46+
end

0 commit comments

Comments
 (0)