Skip to content

Commit 37a4ad8

Browse files
fix: fix remake_buffer adjoint when tangent is an array
1 parent 295edb6 commit 37a4ad8

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,19 @@ function ChainRulesCore.rrule(
8585
f′ = NoTangent()
8686
indp′ = NoTangent()
8787

88-
tunable = selected_tangents(buf′.tunable, tunable_idxs)
89-
discrete = selected_tangents(buf′.discrete, disc_idxs)
90-
constant = selected_tangents(buf′.constant, const_idxs)
91-
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
88+
if buf′ isa AbstractArray
89+
tunable = selected_tangents(buf′, tunable_idxs)
90+
discrete = constant = nonnumeric = NoTangent()
91+
vals′ = map(i -> buf′[i.idx], idxs)
92+
else
93+
tunable = selected_tangents(buf′.tunable, tunable_idxs)
94+
discrete = selected_tangents(buf′.discrete, disc_idxs)
95+
constant = selected_tangents(buf′.constant, const_idxs)
96+
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
97+
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
98+
end
9299
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
93100
idxs′ = NoTangent()
94-
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
95101
return f′, indp′, oldbuf′, idxs′, vals′
96102
end
97103
end

0 commit comments

Comments
 (0)