Skip to content

Commit 8fae11b

Browse files
authored
Merge pull request #278 from DhairyaLGandhi/zyg
Add adjoint for `create_array`
2 parents f55b96e + a43186d commit 8fae11b

File tree

5 files changed

+48
-1
lines changed

5 files changed

+48
-1
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.11.0"
66
[deps]
77
AbstractAlgebra = "c3fe647b-3220-5bb0-a1ea-a7954cac585d"
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
9+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1011
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1112
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -22,6 +23,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2223
[compat]
2324
AbstractAlgebra = "0.9, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15"
2425
AbstractTrees = "0.3"
26+
ChainRulesCore = "0.9"
2527
Combinatorics = "1.0"
2628
ConstructionBase = "1.1"
2729
DataStructures = "0.18"
@@ -40,6 +42,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4042
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
4143
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4244
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4346

4447
[targets]
45-
test = ["Test", "Random", "PkgBenchmark", "BenchmarkTools", "Pkg"]
48+
test = ["Test", "Random", "PkgBenchmark", "BenchmarkTools", "Pkg", "Zygote"]

src/SymbolicUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,7 @@ include("api.jl")
5353

5454
include("code.jl")
5555

56+
# ADjoints
57+
include("adjoints.jl")
58+
5659
end # module

src/adjoints.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using ChainRulesCore
2+
import ChainRulesCore: rrule
3+
import .Code
4+
5+
function rrule(::typeof(Code.create_array), A::Type{<:AbstractArray}, T, u::Val{j}, d::Val{dims}, elems...) where {dims, j}
6+
y = Code.create_array(A, T, u, d, elems...)
7+
function create_array_pullback(Δ)
8+
dx = Δ
9+
(NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dx..., ntuple(_ -> DoesNotExist(), length(elems) - prod(dims) + j)...)
10+
end
11+
y, create_array_pullback
12+
end

test/adjoints.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using SymbolicUtils, SymbolicUtils.Code
2+
using Zygote
3+
4+
@testset "create_array adjoint" begin
5+
elems = (1,2,3,4,5,)
6+
7+
Ts = (Float64, Float32, Float16, Int64, Int32)
8+
dims_candidates = ((1, (2,3)), (2, (1,3)))
9+
As = (Array,)
10+
11+
for T in Ts,
12+
dims in dims_candidates,
13+
A in As
14+
15+
u, dim = dims
16+
ŷ, pb = Zygote.pullback(elems) do elems
17+
SymbolicUtils.Code.create_array(A, T, Val(u), Val(dim), elems...)
18+
end
19+
y = SymbolicUtils.Code.create_array(A, T, Val(u), Val(dim), elems...)
20+
@test y ==
21+
22+
gs = pb(ones(T, length(elems)))
23+
@test length(gs[1]) == length(elems)
24+
for i = 1:(prod(dim)-1)
25+
@test gs[1][i] == one(eltype(ŷ))
26+
end
27+
end
28+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ include("code.jl")
2121
include("nf.jl")
2222
include("interface.jl")
2323
include("fuzz.jl")
24+
include("adjoints.jl")

0 commit comments

Comments
 (0)