Skip to content

Commit ff57997

Browse files
authored
Merge pull request #1242 from bogiebro/master
Fix #1241
2 parents dad65a8 + 99d89b0 commit ff57997

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

src/lib/base.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ function accum(a::AbstractDict, b::AbstractDict)
2929
end
3030

3131
@adjoint function getindex(d::AbstractDict, k)
32-
d[k], function (Δ)
32+
val = d[k]
33+
function dict_getindex_pullback(Δ)
34+
accum_param(__context__, val, Δ) === nothing && return
3335
grad = grad_mut(__context__, d)
3436
grad[k] = accum(get(grad, k, nothing), Δ)
3537
return (grad, nothing)
3638
end
39+
val, dict_getindex_pullback
3740
end
3841

3942
@adjoint! function setindex!(d::AbstractDict, v, k)

test/lib/base.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@testset "base.jl" begin
2+
@testset "Dict getindex with implicit params" begin
3+
d = Dict{String, Vector{Float64}}("key"=>ones(4))
4+
fn() = d["key"][2]
5+
result1 = gradient(fn, Params([d["key"]]))[d["key"]]
6+
7+
x = d["key"]
8+
fn2() = x[2]
9+
result2 = gradient(fn2, Params([x]))[x]
10+
11+
@test result1 == result2
12+
end
13+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using CUDA: has_cuda
2929
@testset "lib" begin
3030
include("lib/number.jl")
3131
include("lib/lib.jl")
32+
include("lib/base.jl")
3233
include("lib/array.jl")
3334
end
3435

0 commit comments

Comments
 (0)