Skip to content

Commit b3758ad

Browse files
committed
ZygoteRules -> ChainRules
1 parent 830a879 commit b3758ad

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ authors = ["Michael Abbott"]
44
version = "0.0.8"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
910
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
10-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
11-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1211

1312
[compat]
13+
ChainRulesCore = "0.10.9"
1414
NamedDims = "0.2.16"
1515
OffsetArrays = "1"
16-
ZygoteRules = "0.2"
1716
julia = "1.3"
1817

1918
[extras]

src/LazyStack.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,19 +339,19 @@ function LazyStack.stack(s::Symbol, args...)
339339
ensure_named(data, name_last)
340340
end
341341

342-
#===== Zygote =====#
342+
#===== Gradients =====#
343343

344-
using ZygoteRules: @adjoint
344+
using ChainRulesCore: ChainRulesCore, rrule, NoTangent
345345

346-
@adjoint function stack(vec::AbstractArray{<:AbstractArray{<:Any,IN}}) where {IN}
347-
stack(vec), Δ -> ([view(Δ, ntuple(_->(:),IN)..., Tuple(I)...) for I in eachindex(vec)],)
346+
function ChainRulesCore.rrule(::typeof(stack), vec::AbstractArray{<:AbstractArray{<:Any,IN}}) where {IN}
347+
stack(vec), Δ -> (NoTangent(), [view(Δ, ntuple(_->(:),IN)..., Tuple(I)...) for I in eachindex(vec)],)
348348
end
349349

350-
@adjoint function stack(tup::Tuple{Vararg{<:AbstractArray{<:Any,IN}}}) where {IN}
351-
stack(tup), Δ -> (ntuple(i -> view(Δ, ntuple(_->(:),IN)..., i), length(tup)),)
350+
function ChainRulesCore.rrule(::typeof(stack), tup::Tuple{Vararg{<:AbstractArray{<:Any,IN}}}) where {IN}
351+
stack(tup), Δ -> (NoTangent(), ntuple(i -> view(Δ, ntuple(_->(:),IN)..., i), length(tup)),)
352352
end
353353

354-
@adjoint function stack(gen::Base.Generator)
354+
function ChainRulesCore.rrule(::typeof(stack), gen::Base.Generator)
355355
stack(gen), Δ -> error("not yet!")
356356
end
357357

0 commit comments

Comments
 (0)