Skip to content

Commit ffe4976

Browse files
committed
Accumulated control flow fixes
1 parent 3106100 commit ffe4976

File tree

6 files changed

+375
-71
lines changed

6 files changed

+375
-71
lines changed

Manifest.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1818

1919
[[ChainRules]]
2020
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
21-
git-tree-sha1 = "948d680b9dfcedcbe58e4d380f3737b4f3e5ac7f"
22-
repo-rev = "kf/diffractorbackport"
23-
repo-url = "https://github.com/Keno/ChainRules.jl.git"
21+
path = "/home/keno/.julia/dev/ChainRules"
2422
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
2523
version = "0.8.18"
2624

src/extra_rules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
145145
Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...)
146146
end
147147

148+
# TODO: What to do about these integer rules
149+
@ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type)
150+
148151
ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
149152

150153
# Skip AD'ing through the axis computation

src/stage1/generated.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,18 @@ function (::∂⃖rrule{N})(z, z̄) where {N}
189189
y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{minus1(N)}(), ȳ, z̄)
190190
end
191191

192+
function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N}
193+
# A few intrinsic functions are inserted by the compiler, so they need to
194+
# be handled here. Otherwise, we just throw an appropriate error.
195+
if f === Core.Intrinsics.not_int && length(args) == 1
196+
return f(args...), EvenOddOdd{1, c_order(N)}(
197+
Δ->(NoTangent(), NoTangent()),
198+
Δ->NoTangent())
199+
end
200+
201+
error("Rewrite reached intrinsic function $f. Missing rule?")
202+
end
203+
192204
# The static parameter on `f` disables the compileable_sig heuristic
193205
function (::∂⃖{N})(f::T, args...) where {T, N}
194206
if N == 1
@@ -297,7 +309,7 @@ end
297309

298310
struct tuple_back{M}; end
299311
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
300-
(::tuple_back{N})(Δ::NoTangent) where {N} = Core.tuple(NoTangent(), ntuple(i->NoTangent(), N)...)
312+
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
301313

302314
function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
303315
Core.tuple(args...),
@@ -333,7 +345,7 @@ end
333345
a.u(r)
334346
end
335347

336-
function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Tuple...) where {N}
348+
function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Union{Tuple, NamedTuple}...) where {N}
337349
@assert iterate === Base.iterate
338350
x, ∂⃖f = Core._apply_iterate(iterate, this, (f,), args...)
339351
return x, ApplyOdd{1, c_order(N)}(UnApply{map(length, args)}(), ∂⃖f)

src/stage1/hacks.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
# Updated copy of the same code in Base, but with bugs fixed
2-
using Core.Compiler: count_added_node!, add!, NewSSAValue
2+
using Core.Compiler: count_added_node!, add!, NewSSAValue, add_pending!,
3+
StmtRange, BasicBlock
34

5+
Base.length(c::Core.Compiler.NewNodeStream) = Core.Compiler.length(c)
46
Base.setindex!(i::Instruction, args...) = Core.Compiler.setindex!(i, args...)
7+
Core.Compiler.BasicBlock(x::UnitRange) =
8+
BasicBlock(StmtRange(first(x), last(x)))
9+
Core.Compiler.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
10+
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
11+
Base.size(x::Core.Compiler.UnitRange) = Core.Compiler.size(x)
12+
513
function my_insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, attach_after::Bool=false)
614
@assert inst.effect_free_computed
715
if isa(before, SSAValue)

0 commit comments

Comments
 (0)