Skip to content

Commit 2a102fe

Browse files
authored
fix: materialize softmax (#1332)
1 parent a93a83a commit 2a102fe

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.117"
4+
version = "0.2.118"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ for (jlop, hloop) in (
77
end
88

99
function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
10+
x = T.(Reactant.materialize_traced_array(x))
1011
max_ = maximum(x; dims)
1112
diff = exp.(x .- max_)
1213
@trace if all(isfinite, max_)
@@ -19,6 +20,7 @@ function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) whe
1920
end
2021

2122
function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) where {T}
23+
x = T.(Reactant.materialize_traced_array(x))
2224
max_ = maximum(x; dims)
2325
diff = x .- max_
2426
@trace if all(isfinite, max_)

test/nn/nnlib.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,12 @@ end
418418

419419
@test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) NNlib.upsample_nearest(x, (2, 2))
420420
end
421+
422+
@testset "softmax/logsoftmax reshaped input" begin
423+
x = rand(Float32, 3, 4, 5)
424+
x_ra = reshape(Reactant.to_rarray(x), 12, 5)
425+
x = reshape(x, 12, 5)
426+
427+
@test @jit(NNlib.softmax(x_ra)) NNlib.softmax(x)
428+
@test @jit(NNlib.logsoftmax(x_ra)) NNlib.logsoftmax(x)
429+
end

0 commit comments

Comments
 (0)