Skip to content

Commit e047fd8

Browse files
authored
Merge pull request #396 from JuliaDiff/ox/inplaceorder
Change argument order on InplaceableThunk and fix deprecated tests
2 parents 4b33290 + cd458a1 commit e047fd8

File tree

7 files changed

+39
-25
lines changed

7 files changed

+39
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.10.11"
3+
version = "0.10.12"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/deprecated.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,6 @@ for T in (:Thunk, :InplaceableThunk)
5454
return unthunk(x)
5555
end
5656
end
57+
58+
59+
Base.@deprecate InplaceableThunk(t::Thunk, add!) InplaceableThunk(add!, t)

src/differentials/thunks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a)
197197

198198

199199
"""
200-
InplaceableThunk(val::Thunk, add!::Function)
200+
InplaceableThunk(add!::Function, val::Thunk)
201201
202202
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function.
203203
@@ -209,8 +209,8 @@ Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
209209
and destroy its inplacability.
210210
"""
211211
struct InplaceableThunk{T<:Thunk,F} <: AbstractThunk
212-
val::T
213212
add!::F
213+
val::T
214214
end
215215

216216
unthunk(x::InplaceableThunk) = unthunk(x.val)

test/accumulation.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989

9090
@testset "AbstractThunk $(typeof(thunk))" for thunk in (
9191
@thunk(-1.0*ones(2, 2)),
92-
InplaceableThunk(@thunk(-1.0*ones(2, 2)), x -> x .-= ones(2, 2)),
92+
InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))),
9393
)
9494
@testset "in place" begin
9595
accumuland = [1.0 2.0; 3.0 4.0]
@@ -109,10 +109,10 @@
109109
end
110110

111111
@testset "not actually inplace but said it was" begin
112-
ithunk = InplaceableThunk(
113-
@thunk(@assert false), # this should never be used in this test
114-
x -> 77*ones(2, 2) # not actually inplace (also wrong)
115-
)
112+
# thunk should never be used in this test
113+
ithunk = InplaceableThunk(@thunk(@assert false)) do x
114+
77*ones(2, 2) # not actually inplace (also wrong)
115+
end
116116
accumuland = ones(2, 2)
117117
@assert ChainRulesCore.debug_mode() == false
118118
# without debug being enabled should return the result, not error
@@ -127,7 +127,7 @@
127127

128128
@testset "showerror BadInplaceException" begin
129129
BadInplaceException = ChainRulesCore.BadInplaceException
130-
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing)
130+
ithunk = InplaceableThunk(->nothing, @thunk(@assert false))
131131
msg = sprint(showerror, BadInplaceException(ithunk, [22], [23]))
132132
@test occursin("22", msg)
133133

test/deprecated.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
11
@testset "NO_FIELDS" begin
2-
@test (@test_deprecated NO_FIELDS) isa NoTangent
2+
# Following doesn't work because of some deprecate_binding weirdness with not printing
3+
# @test (@test_deprecated NO_FIELDS) isa NoTangent
4+
# So just test it gives the old behavour
5+
@test NO_FIELDS isa NoTangent
36
end
47

58
@testset "extern" begin
6-
@test extern(@thunk(3)) == 3
7-
@test extern(@thunk(@thunk(3))) == 3
9+
@test (@test_deprecated extern(@thunk(3))) == 3
10+
@test (@test_deprecated extern(@thunk(@thunk(3)))) == 3
811

9-
@test extern(Tangent{Foo}(x=2.0)) == (;x=2.0)
10-
@test extern(Tangent{Tuple{Float64,}}(2.0)) == (2.0,)
11-
@test extern(Tangent{Dict}(Dict(4 => 3))) == Dict(4 => 3)
12+
@test (@test_deprecated extern(Tangent{Foo}(x=2.0))) == (;x=2.0)
13+
@test (@test_deprecated extern(Tangent{Tuple{Float64,}}(2.0))) == (2.0,)
14+
@test (@test_deprecated extern(Tangent{Dict}(Dict(4 => 3)))) == Dict(4 => 3)
1215

1316
# with differentials on the inside
14-
@test extern(Tangent{Foo}(x=@thunk(0+2.0))) == (;x=2.0)
15-
@test extern(Tangent{Tuple{Float64,}}(@thunk(0+2.0))) == (2.0,)
16-
@test extern(Tangent{Dict}(Dict(4 => @thunk(3)))) == Dict(4 => 3)
17+
@test (@test_deprecated extern(Tangent{Foo}(x=@thunk(0+2.0)))) == (;x=2.0)
18+
@test (@test_deprecated extern(Tangent{Tuple{Float64,}}(@thunk(0+2.0)))) == (2.0,)
19+
@test (@test_deprecated extern(Tangent{Dict}(Dict(4 => @thunk(3))))) == Dict(4 => 3)
1720

1821
z = ZeroTangent()
19-
@test extern(z) === false
22+
@test (@test_deprecated extern(z)) === false
23+
24+
# @test_throws doesn't play nice with `@test_deprecated` so have to be loud
2025
dne = NoTangent()
2126
@test_throws Exception extern(dne)
22-
E = ChainRulesCore.NotImplementedException
23-
@test_throws E extern(ni)
27+
ni = @not_implemented("no")
28+
@test_throws ChainRulesCore.NotImplementedException extern(ni)
2429
end
2530

2631

2732
@testset "Deprecated: calling thunks should call inner function" begin
28-
@test_deprecated (@thunk(3))() == 3
29-
@test_deprecated (@thunk(@thunk(3)))() isa Thunk
33+
@test (@test_deprecated (@thunk(3))()) == 3
34+
@test (@test_deprecated (@thunk(@thunk(3)))()) isa Thunk
3035
end
36+
37+
@testset "Deprecated: Inplacable Thunk argument order" begin
38+
@test (@test_deprecated InplaceableThunk(@thunk([1]), x->x.+=1)) isa InplaceableThunk
39+
end

test/differentials/thunks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
@test @thunk(3) isa Thunk
55

66
@testset "==" begin
7-
@test @thunk(3.2) == InplaceableThunk(@thunk(3.2), x -> x + 3.2)
7+
@test @thunk(3.2) == InplaceableThunk(x -> x + 3.2, @thunk(3.2))
88
@test @thunk(3.2) == 3.2
9-
@test 3.2 == InplaceableThunk(@thunk(3.2), x -> x + 3.2)
9+
@test 3.2 == InplaceableThunk(x -> x + 3.2, @thunk(3.2))
1010
end
1111

1212
@testset "iterate" begin

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,6 @@ using Test
2121
include("rules.jl")
2222
include("rule_definition_tools.jl")
2323
include("config.jl")
24+
25+
include("deprecated.jl")
2426
end

0 commit comments

Comments
 (0)