Skip to content

Commit 91e5e70

Browse files
chore: handle arrays in MTKParameters pullback
1 parent d7fa2b9 commit 91e5e70

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
77
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
88
function mtp_pullback(dt)
99
dt = unthunk(dt)
10-
(NoTangent(), dt.tunable[1:length(tunables)],
10+
dtunables = dt isa AbstractArray ? dt : dt.tunable
11+
(NoTangent(), dtunables[1:length(tunables)],
1112
ntuple(_ -> NoTangent(), length(args))...)
1213
end
1314
MTK.MTKParameters(tunables, args...), mtp_pullback

0 commit comments

Comments
 (0)