Skip to content

Commit 284e463

Browse files
Merge pull request #3100 from DhairyaLGandhi/dg/crc3
Allow arrays in `MTKParameters` pullback
2 parents d7fa2b9 + 91e5e70 commit 284e463

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
77
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
88
function mtp_pullback(dt)
99
dt = unthunk(dt)
10-
(NoTangent(), dt.tunable[1:length(tunables)],
10+
dtunables = dt isa AbstractArray ? dt : dt.tunable
11+
(NoTangent(), dtunables[1:length(tunables)],
1112
ntuple(_ -> NoTangent(), length(args))...)
1213
end
1314
MTK.MTKParameters(tunables, args...), mtp_pullback

0 commit comments

Comments
 (0)