We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d7fa2b9 commit 91e5e70Copy full SHA for 91e5e70
ext/MTKChainRulesCoreExt.jl
@@ -7,7 +7,8 @@ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
7
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
8
function mtp_pullback(dt)
9
dt = unthunk(dt)
10
- (NoTangent(), dt.tunable[1:length(tunables)],
+ dtunables = dt isa AbstractArray ? dt : dt.tunable
11
+ (NoTangent(), dtunables[1:length(tunables)],
12
ntuple(_ -> NoTangent(), length(args))...)
13
end
14
MTK.MTKParameters(tunables, args...), mtp_pullback
0 commit comments