Skip to content

Commit 585bc11

Browse files
committed
Actually use scheduler in reverse rules
1 parent df7a2c5 commit 585bc11

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/utility/diffable_threads.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ dtmap!!(args...; scheduler = Defaults.scheduler[]) = tmap!(args...; scheduler)
1313
# Follows the `map` rrule from ChainRules.jl but specified for the case of one AbstractArray that is being mapped
1414
# https://github.com/JuliaDiff/ChainRules.jl/blob/e245d50a1ae56ce46fc8c1f0fe9b925964f1146e/src/rulesets/Base/base.jl#L243
1515
function ChainRulesCore.rrule(
16-
config::RuleConfig{>:HasReverseMode}, ::typeof(dtmap), f, A::AbstractArray; kwargs...
16+
config::RuleConfig{>:HasReverseMode}, ::typeof(dtmap), f, A::AbstractArray;
17+
scheduler = Defaults.scheduler[]
1718
)
18-
el_rrules = tmap(A; kwargs...) do a
19+
el_rrules = tmap(A; scheduler) do a
1920
rrule_via_ad(config, f, a)
2021
end
2122
y = map(first, el_rrules)
@@ -24,7 +25,7 @@ function ChainRulesCore.rrule(
2425

2526
function dtmap_pullback(dy_raw)
2627
dys = unthunk(dy_raw)
27-
backevals = tmap(el_rrules, dys; kwargs...) do el_rrule, dy
28+
backevals = tmap(el_rrules, dys; scheduler) do el_rrule, dy
2829
last(el_rrule)(dy)
2930
end
3031
df = f_projector(sum(first, backevals))

0 commit comments

Comments
 (0)