Skip to content

Commit b36b2c4

Browse files
committed
Update ChainRules for Zero -> ZeroTangent, and support a.b[constindex].c[index], fixes #278
1 parent 065ce6e commit b36b2c4

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.12.29"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -20,6 +21,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
2021

2122
[compat]
2223
ArrayInterface = "3.1.9"
24+
ChainRulesCore = "0.10"
2325
DocStringExtensions = "0.8"
2426
IfElse = "0.1"
2527
OffsetArrays = "1.4.1"
@@ -30,5 +32,5 @@ Static = "0.2"
3032
StrideArraysCore = "0.1.11"
3133
ThreadingUtilities = "0.4.2"
3234
UnPack = "1"
33-
VectorizationBase = "0.20.4"
35+
VectorizationBase = "0.20.4,0.21"
3436
julia = "1.5"

src/condense_loopset.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
77
gf = GlobalRef(Core,:getfield)
88
for f 1:fieldcount(T)
99
TF = fieldtype(T, f)
10+
Base.issingletontype(TF) && continue
1011
gfcall = Expr(:call, gf, sym, f)
11-
if Base.issingletontype(TF)
12-
nothing
13-
elseif fieldcount(TF) 0
12+
if fieldcount(TF) 0
1413
push!(t.args, gfcall)
1514
elseif TF <: DataType
1615
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticType), gfcall)))
@@ -23,7 +22,6 @@ function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
2322
return nothing
2423
end
2524
@generated function flatten_to_tuple(r::T) where {T}
26-
numfields = fieldcount(T)
2725
body = Expr(:block, Expr(:meta,:inline))
2826
t = Expr(:tuple)
2927
if Base.issingletontype(T)
@@ -39,10 +37,9 @@ end
3937
body
4038
end
4139
function rebuild_fields(offset::Int, ::Type{T}) where {T}
42-
numfields = fieldcount(T)
4340
gf = GlobalRef(Core,:getfield)
4441
call = (T <: Tuple) ? Expr(:tuple) : Expr(:new, T)
45-
for f 1:numfields
42+
for f 1:fieldcount(T)
4643
TF = fieldtype(T, f)
4744
if Base.issingletontype(TF)
4845
push!(call.args, TF.instance)

src/parse/memory_ops_common.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
dottosym(x) = x
2-
dottosym(x::Expr) = Symbol(dottosym(x.args[1]), "###extractarray###", x.args[2].value)
1+
dottosym(x)::Symbol = x
2+
3+
function dottosym(x::Expr)::Symbol
4+
s1 = dottosym(x.args[1])
5+
xa2 = x.args[2]
6+
xa2 isa QuoteNode ? Symbol(s1, "###extractarray###", xa2.value) : Symbol(s1, "###extractarray###", xa2)
7+
end
38
function extract_array_symbol_from_ref!(ls::LoopSet, ex::Expr, offset1::Int)::Symbol
49
ar = ex.args[1 + offset1]
510
if isa(ar, Symbol)

src/simdfunctionals/vmap_grad_rrule.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import .ChainRulesCore
44
function ChainRulesCore.rrule(::typeof(tanh_fast), x)
55
t = tanh_fast(x)
66
= let t = t
7-
y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
7+
y -> (ChainRulesCore.ZeroTangent(), mul_fast(vfnmadd_fast(t, t, one(t)), y))
88
end
99
t, ∂
1010
end
1111
function ChainRulesCore.rrule(::typeof(sigmoid_fast), x)
1212
s = sigmoid_fast(x)
1313
= let s = s
14-
y -> (ChainRulesCore.Zero(), mul_fast(vfnmadd_fast(s, s, s), y))
14+
y -> (ChainRulesCore.ZeroTangent(), mul_fast(vfnmadd_fast(s, s, s), y))
1515
end
1616
s, ∂
1717
end
@@ -20,7 +20,7 @@ function ChainRulesCore.rrule(::typeof(relu), v)
2020
cmp = v < z
2121
r = ifelse(cmp, z, v)
2222
= let cmp = cmp
23-
y -> (ChainRulesCore.Zero(), ifelse(cmp, zero(y), y))
23+
y -> (ChainRulesCore.ZeroTangent(), ifelse(cmp, zero(y), y))
2424
end
2525
r, ∂
2626
end
@@ -64,7 +64,7 @@ end
6464
@generated function (b::SIMDMapBack{K,T})(Δ::A) where {K,T,A}
6565
preloop = Expr(:block, :(jacs = b.jacs))
6666
loop_body = Expr(:block, :(Δᵢ = Δ[i]))
67-
ret = Expr(:tuple, ChainRulesCore.Zero(), ChainRulesCore.Zero())
67+
ret = Expr(:tuple, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent())
6868
for k 1:K
6969
jₖ = Symbol(:j_, k)
7070
push!(preloop.args, :($jₖ = jacs[$k]))

test/miscellaneous.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,15 +1069,15 @@ end
10691069
@test grad!(zeros(5), ones(5), ones(3)) grad_avx!(zeros(5), ones(5), ones(3)) grad_avx_base!(zeros(5), ones(5), ones(3)) grad_avx_eval!(zeros(5), ones(5), ones(3))
10701070

10711071
nta = rand(2)
1072-
namedtuple = (a = copy(nta), b = 10.0)
1072+
namedtuple = (a = (1,copy(nta)), b = 10.0)
10731073
@turbo for i in 1:2
1074-
namedtuple.a[i] += namedtuple.b
1074+
namedtuple.a[2][i] += namedtuple.b
10751075
end
1076-
@test namedtuple.a == nta .+ 10
1076+
@test namedtuple.a[2].c == nta .+ 10
10771077

1078-
let A = rand(T, 20, 30); B = rand(T, 20, 30); C = rand(T, 20, 30, 30);
1079-
@test threemulaccum_base(A,B,C) threemulaccum_lv(A,B,C)
1080-
end
1078+
let A = rand(T, 20, 30); B = rand(T, 20, 30); C = rand(T, 20, 30, 30);
1079+
@test threemulaccum_base(A,B,C) threemulaccum_lv(A,B,C)
1080+
end
10811081
end
10821082
for T [Int16, Int32, Int64]
10831083
n = 8sizeof(T) - 1

0 commit comments

Comments
 (0)