Skip to content

Commit 3a7dc65

Browse files
committed
Allow older ChainRulesCore
1 parent 680807d commit 3a7dc65

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.30"
4+
version = "0.12.31"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
109
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -21,7 +20,6 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
2120

2221
[compat]
2322
ArrayInterface = "3.1.9"
24-
ChainRulesCore = "0.10"
2523
DocStringExtensions = "0.8"
2624
IfElse = "0.1"
2725
OffsetArrays = "1.4.1"

src/simdfunctionals/vmap_grad_rrule.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11

22
import .ChainRulesCore
33

4+
if isdefined(ChainRulesCore, :ZeroTangent)
5+
const ChainRulesZero = ChainRulesCore.ZeroTangent
6+
else
7+
const ChainRulesZero = ChainRulesCore.Zero
8+
end
9+
410
function ChainRulesCore.rrule(::typeof(tanh_fast), x)
511
t = tanh_fast(x)
612
= let t = t
7-
y -> (ChainRulesCore.ZeroTangent(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
13+
y -> (ChainRulesZero(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
814
end
915
t, ∂
1016
end
1117
function ChainRulesCore.rrule(::typeof(sigmoid_fast), x)
1218
s = sigmoid_fast(x)
1319
= let s = s
14-
y -> (ChainRulesCore.ZeroTangent(), mul_fast(vfnmadd_fast(s, s, s), y))
20+
y -> (ChainRulesZero(), mul_fast(vfnmadd_fast(s, s, s), y))
1521
end
1622
s, ∂
1723
end
@@ -20,7 +26,7 @@ function ChainRulesCore.rrule(::typeof(relu), v)
2026
cmp = v < z
2127
r = ifelse(cmp, z, v)
2228
= let cmp = cmp
23-
y -> (ChainRulesCore.ZeroTangent(), ifelse(cmp, zero(y), y))
29+
y -> (ChainRulesZero(), ifelse(cmp, zero(y), y))
2430
end
2531
r, ∂
2632
end
@@ -64,7 +70,7 @@ end
6470
@generated function (b::SIMDMapBack{K,T})(Δ::A) where {K,T,A}
6571
preloop = Expr(:block, :(jacs = b.jacs))
6672
loop_body = Expr(:block, :(Δᵢ = Δ[i]))
67-
ret = Expr(:tuple, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent())
73+
ret = Expr(:tuple, ChainRulesZero(), ChainRulesZero())
6874
for k 1:K
6975
jₖ = Symbol(:j_, k)
7076
push!(preloop.args, :($jₖ = jacs[$k]))

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ const START_TIME = time()
1414
@time @testset "LoopVectorization.jl" begin
1515

1616
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part1"
17-
@time Aqua.test_all(LoopVectorization, stale_deps = false)
18-
@time Aqua.test_stale_deps(LoopVectorization, ignore = [:ChainRulesCore])
17+
@time Aqua.test_all(LoopVectorization)
1918
# @test isempty(detect_unbound_args(LoopVectorization))
2019

2120
@time include("printmethods.jl")

0 commit comments

Comments
 (0)