Skip to content

Commit cd74e12

Browse files
committed
Merge branch 'master' of github.com:FluxML/FastAI.jl
2 parents 9962126 + fcdf767 commit cd74e12

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

FastVision/src/encodings/onehot.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,14 @@ end
3333
# `logitcrossentropy(...; dims = 3)` doesn't work on GPU:
3434

3535
function _segmentationloss(ypreds, ys; kwargs...)
36-
sz = size(ypreds)
37-
ypreds = reshape(ypreds, :, sz[end - 1], sz[end])
38-
ys = reshape(ys, :, size(ys, 3), size(ys, 4))
36+
sz_preds = size(ypreds)
37+
ypreds = reshape(ypreds, :, sz_preds[end - 1], sz_preds[end])
38+
sz = size(ys)
39+
ys = reshape(ys, :, sz[end - 1], sz[end])
3940
Flux.Losses.logitcrossentropy(ypreds, ys; dims = 2, kwargs...)
4041
end
42+
43+
@testset "segmentationloss" begin
44+
@test _segmentationloss(zeros(10, 10, 3, 5), zeros(10, 10, 3, 5)) == 0
45+
@test _segmentationloss(zeros(10, 10, 10, 3, 5), zeros(10, 10, 10, 3, 5)) == 0
46+
end

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ JLD2 = "0.4"
3737
MLDatasets = "0.7"
3838
MLUtils = "0.2.6"
3939
Parameters = "0.12"
40-
PrettyTables = "1.2"
40+
PrettyTables = "1.2, 2"
4141
Reexport = "1.0"
4242
Requires = "1"
4343
Setfield = "0.7, 0.8"

0 commit comments

Comments
 (0)