
Commit cab3a49

oscardssmith and Keno authored
fix reverse mode (#175)
* start fixing reverse mode
* fix
* fix tests

Co-authored-by: Keno Fischer <[email protected]>
1 parent 096f918 commit cab3a49

File tree

6 files changed: +91 -96 lines changed

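For context before the per-file diffs: the reverse-mode entry points this commit repairs are the ones exercised in test/regression.jl below. A minimal smoke test, assuming the `gradient` and `PrimeDerivativeBack` names the test file itself uses:

```julia
using Diffractor: gradient, PrimeDerivativeBack

# Reverse-mode gradient of a scalar-valued function of an array.
gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1]   # expected: [2.0, 2.0, 2.0]

# PrimeDerivativeBack(f) is the reverse-mode derivative operator
# (bound to `bwd` in the regression tests below).
bwd = PrimeDerivativeBack
bwd(sin)(1.0)                                      # expected: cos(1.0)
```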

Manifest.toml

Lines changed: 19 additions & 19 deletions
@@ -27,9 +27,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
 [[deps.ChainRules]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
-git-tree-sha1 = "8bae903893aeeb429cf732cf1888490b93ecf265"
+git-tree-sha1 = "61549d9b52c88df34d21bd306dba1d43bb039c87"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.49.0"
+version = "1.51.0"
 
 [[deps.ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]

@@ -50,9 +50,9 @@ version = "1.0.2"
 
 [[deps.Compat]]
 deps = ["UUIDs"]
-git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957"
+git-tree-sha1 = "4e88377ae7ebeaf29a047aa1ee40826e0b708a5d"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "4.6.1"
+version = "4.7.0"
 weakdeps = ["Dates", "LinearAlgebra"]
 
 [deps.Compat.extensions]

@@ -61,13 +61,13 @@ weakdeps = ["Dates", "LinearAlgebra"]
 [[deps.CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
-version = "1.0.2+0"
+version = "1.0.5+0"
 
 [[deps.Cthulhu]]
 deps = ["CodeTracking", "FoldingTrees", "InteractiveUtils", "JuliaSyntax", "PrecompileTools", "Preferences", "REPL", "TypedSyntax", "UUIDs", "Unicode", "WidthLimitedIO"]
-git-tree-sha1 = "aac06850ca054d0459ec212aed9788f60dbf79cf"
+git-tree-sha1 = "9b804378bbe126f64ca3b4cd4b5dc9e44ea02f70"
 uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
-version = "2.8.15"
+version = "2.9.1"
 
 [[deps.DataAPI]]
 git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"

@@ -126,9 +126,9 @@ uuid = "82899510-4779-5014-852e-03e436cf321d"
 version = "1.0.0"
 
 [[deps.JuliaSyntax]]
-git-tree-sha1 = "3884259b6852ed89c7036c455551a556d8a3a124"
+git-tree-sha1 = "3b993680318327a645c0240baf653433a0f09953"
 uuid = "70703baa-626e-46a2-a12c-08ffd08c73b4"
-version = "0.4.1"
+version = "0.4.5"
 
 [[deps.LibGit2]]
 deps = ["Base64", "NetworkOptions", "Printf", "SHA"]

@@ -143,9 +143,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 
 [[deps.LogExpFunctions]]
 deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
-git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
+git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-version = "0.3.23"
+version = "0.3.24"
 
 [deps.LogExpFunctions.extensions]
 LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"

@@ -192,9 +192,9 @@ version = "1.6.0"
 
 [[deps.PrecompileTools]]
 deps = ["Preferences"]
-git-tree-sha1 = "259e206946c293698122f63e2b513a7c99a244e8"
+git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81"
 uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
-version = "1.1.1"
+version = "1.1.2"
 
 [[deps.Preferences]]
 deps = ["TOML"]

@@ -238,9 +238,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
 
 [[deps.SortingAlgorithms]]
 deps = ["DataStructures"]
-git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00"
+git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
 uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "1.1.0"
+version = "1.1.1"
 
 [[deps.SparseArrays]]
 deps = ["Libdl", "LinearAlgebra", "Random", "Serialization"]

@@ -249,9 +249,9 @@ version = "1.10.0"
 
 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
-git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12"
+git-tree-sha1 = "832afbae2a45b4ae7e831f86965469a24d1d8a83"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.5.25"
+version = "1.5.26"
 
 [[deps.StaticArraysCore]]
 git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"

@@ -304,9 +304,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[deps.TypedSyntax]]
 deps = ["CodeTracking", "JuliaSyntax"]
-git-tree-sha1 = "6da6670f978221bea4f501b600f34ec20cb9516e"
+git-tree-sha1 = "e38949656d1443d30339d4fc1088fdc49c8f652e"
 uuid = "d265eb64-f81a-44ad-a842-4247ee1503de"
-version = "1.1.10"
+version = "1.2.1"
 
 [[deps.UUIDs]]
 deps = ["Random", "SHA"]

src/codegen/reverse.jl

Lines changed: 4 additions & 4 deletions
@@ -37,7 +37,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
             push!(pn.edges, bb)
             if !isa(stmt.val, SSAValue)
                 push!(pn.values, insert_node!(ir, i,
-                    non_effect_free(NewInstruction(stmt.val))))
+                    NewInstruction(stmt.val)))
             else
                 push!(pn.values, stmt.val)
             end

@@ -78,7 +78,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
     for block in cfg.blocks
         if length(block.preds) != 0
             insert_node!(ir, block.stmts.start,
-                non_effect_free(NewInstruction(Expr(:phi_placeholder, copy(block.preds)))))
+                NewInstruction(Expr(:phi_placeholder, copy(block.preds))))
         end
     end
 

@@ -614,14 +614,14 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
         end
     end
     tup = terminator_insert_node!(
-        effect_free(NewInstruction(Expr(:call, tuple, rev[orig_bb_ranges[active_bb]]...), Any, Int32(0))))
+        effect_free_and_nothrow(NewInstruction(Expr(:call, tuple, rev[orig_bb_ranges[active_bb]]...), Any, Int32(0))))
     for succ in succs
         preds = cfg.blocks[succ].preds
         if length(preds) == 1
             val = tup
         else
             selector = findfirst(==(active_bb), preds)
-            val = insert_node_here!(compact, effect_free(NewInstruction(Expr(:call, tuple, selector, tup), Any, Int32(0))), true)
+            val = insert_node_here!(compact, effect_free_and_nothrow(NewInstruction(Expr(:call, tuple, selector, tup), Any, Int32(0))), true)
         end
         pn = phi_nodes[succ]
         push!(pn.edges, active_bb)
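The substantive change here tracks a rename in Julia's compiler internals: the `effect_free` wrapper around `NewInstruction` became `effect_free_and_nothrow`, and the explicit `non_effect_free` marker on inserted nodes is no longer needed. A rough sketch of the new pattern (Core.Compiler internals, so the exact names are tied to the Julia nightly this commit targets):

```julia
using Core.Compiler: NewInstruction, effect_free_and_nothrow

# Build an untyped `tuple` call instruction, as in the diff above.
inst = NewInstruction(Expr(:call, tuple, 1, 2), Any, Int32(0))

# Flag it as effect-free and non-throwing, so later optimization passes
# (e.g. DCE) are free to move or delete it.
inst = effect_free_and_nothrow(inst)
```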

src/stage1/recurse.jl

Lines changed: 3 additions & 3 deletions
@@ -4,8 +4,8 @@ using Core.Compiler:
     ReturnNode, SSAValue, SlotNumber, StmtRange,
     bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
     construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
-    insert_node_here!, non_dce_finish!, quoted, retrieve_code_info,
-    scan_slot_def_use, userefs
+    insert_node_here!, effect_free_and_nothrow, non_dce_finish!, quoted, retrieve_code_info,
+    scan_slot_def_use, userefs, SimpleInferenceLattice
 
 using Base.Meta
 

@@ -279,7 +279,7 @@ function optic_transform!(ci, mi, nargs, N)
     domtree = construct_domtree(ir.cfg.blocks)
     defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
     ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
-    ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
+    ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, SimpleInferenceLattice.instance)
     ir = compact!(ir)
 
     nfixedargs = Int(meth.nargs)
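The second adaptation is the lattice argument to `construct_ssa!`: as the diff shows, `Core.Compiler.OptimizerLattice()` is replaced by the singleton `SimpleInferenceLattice.instance`. Schematically (again internal, version-specific API):

```julia
using Core.Compiler: SimpleInferenceLattice

# SimpleInferenceLattice is a singleton type: the API now takes its `.instance`
# field rather than constructing a lattice object the way OptimizerLattice() was.
lattice = SimpleInferenceLattice.instance
```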

test/regression.jl

Lines changed: 50 additions & 50 deletions
@@ -13,18 +13,18 @@ const bwd = Diffractor.PrimeDerivativeBack
 
 
 # Regression tests
-@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0]
 
 function f_broadcast(a)
     l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]]
     return sum(l)
 end
-@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0)
 
 # Make sure that there's no infinite recursion in kwarg calls
 g_kw(;x=1.0) = sin(x)
 f_kw(x) = g_kw(;x)
-@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test bwd(f_kw)(1.0) == bwd(sin)(1.0)
 
 function f_crit_edge(a, b, c, x)
     # A function with two critical edges. This used to trigger an issue where

@@ -43,98 +43,98 @@ function f_crit_edge(a, b, c, x)
 
     return y
 end
-@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0
+@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0
+@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0
+@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0
 
 # Issue #27 - Mixup in lifting of getfield
 let var"'" = bwd
-    @test (x->x^5)''(1.0) == 20. broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test (x->x^5)''(1.0) == 20.
     @test_broken (x->x^5)'''(1.0) == 60.
 end
 
 # Issue #38 - Splatting arrays
-@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0)
+@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0]
 
 # Issue #40 - Symbol type parameters not properly quoted
-@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}() broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}()
 
 # PR #43
 loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
 x43 = rand(10, 10)
-@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}
 
 # PR # 45 - Calling back into AD from ChainRules
-@test_broken y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test_broken y45 ≈ 2.0 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-@test_broken back45(1) == (ZeroTangent(), 1.0) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
+@test y45 ≈ 2.0
+@test back45(1) == (ZeroTangent(), 1.0)
 
 z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
 @test z45 ≈ 2.0
 @test delta45 ≈ 1.0
 
 # PR #82 - getindex on non-numeric arrays
-@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1} broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}
 
 @testset "broadcast" begin
     # derivatives_given_output
-    @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],)
+    @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
+    @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
 
     # frule_via_ad
-    @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],)
     exp_log(x) = exp(log(x))
-    @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
     # closure:
-    @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,)
 
     # array of arrays
-    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
     # must not take fast path
-    @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12]
 
-    @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
     # x/y rule
-    @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,)
     # x.^2 rule
-    @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],)
     # scalar^2 rule
-    @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,)
 
-    @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
+    @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
+    @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)
 
     # Bool output
-    @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent()) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero
+    @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
+    @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
     # Bool input
-    @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero
+    @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
+    @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero
 
-    @test_broken tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test tup_adj[2] isa Transpose broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
+    @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
+    @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
+    @test tup_adj[2] isa Transpose
+    @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
 
     # closure:
-    @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,)
 end
 
 @testset "broadcast, 2nd order" begin
     # calls "split broadcasting generic" with f = unthunk
-    @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
-    @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] broken=true # https://github.com/JuliaDiff/Diffractor.jl/issues/170
+    @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2]
+    @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
     # Control flow support not fully implemented yet for higher-order
     @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12]

@@ -153,4 +153,4 @@ end
     @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
 end
 
-end
+end
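Nearly all of the test churn above is the removal of `broken=true` annotations added while reverse mode was out of commission (issue #170). In Julia's Test stdlib (1.7+), `@test expr broken=true` records an expected failure and errors if the test unexpectedly starts passing, so deleting the keyword re-asserts the test. For example:

```julia
using Test

@test 1 + 1 == 3 broken=true   # recorded as Broken; the suite fails if this ever passes
@test 1 + 1 == 2               # an ordinary assertion, as these tests are again
```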
