@@ -33,12 +33,12 @@ Base.iterate(f::Foo, state) = iterate(f.a, state)
3333function ChainRulesCore. rrule (:: Type{Foo} , a)
3434 foo = Foo (a)
3535 function Foo_pullback (Δfoo)
36- return NoTangent (), Δfoo. a
36+ return NoTangent (), unthunk ( Δfoo) . a
3737 end
3838 return foo, Foo_pullback
3939end
4040function ChainRulesCore. frule ((_, Δa), :: Type{Foo} , a)
41- return Foo (a), Foo (Δa )
41+ return Foo (a), Foo (unthunk (Δa) )
4242end
4343
4444# functor
@@ -49,7 +49,8 @@ function ChainRulesCore.rrule(f::Foo, x)
4949 end
5050 return y, Foo_pullback
5151end
52- function ChainRulesCore. frule ((Δf, Δx), f:: Foo , x)
52+ function ChainRulesCore. frule ((Δf_, Δx), f:: Foo , x)
53+ Δf = unthunk (Δf_)
5354 return f (x), Δf. a + Δx
5455end
5556
@@ -158,7 +159,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
158159
159160 @testset " check not inferred in frule" begin
160161 function ChainRulesCore. frule ((_, Δx), :: typeof (f_noninferrable_frule), x)
161- return (x, x > 0 ? Float64 (Δx) : Float32 (Δx ))
162+ return (x, x > 0 ? Float64 (unthunk ( Δx)) : Float32 (unthunk (Δx) ))
162163 end
163164 function ChainRulesCore. rrule (:: typeof (f_noninferrable_frule), x)
164165 f_noninferrable_frule_pullback (Δy) = (NoTangent (), Δy)
@@ -205,7 +206,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
205206 @testset " check not inferred in pullback" begin
206207 function ChainRulesCore. rrule (:: typeof (f_noninferrable_pullback), x)
207208 function f_noninferrable_pullback_pullback (Δy)
208- return (NoTangent (), x > 0 ? Float64 (Δy) : Float32 (Δy ))
209+ return (NoTangent (), ( x > 0 ? Float64 : Float32)( unthunk (Δy) ))
209210 end
210211 return x, f_noninferrable_pullback_pullback
211212 end
@@ -219,7 +220,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
219220 @testset " check not inferred in thunk" begin
220221 function ChainRulesCore. rrule (:: typeof (f_noninferrable_thunk), x, y)
221222 function f_noninferrable_thunk_pullback (Δz)
222- ∂x = @thunk (x > 0 ? Float64 (Δz) : Float32 (Δz ))
223+ ∂x = @thunk (x > 0 ? Float64 (unthunk ( Δz)) : Float32 (unthunk (Δz) ))
223224 return (NoTangent (), ∂x, Δz)
224225 end
225226 return x + y, f_noninferrable_thunk_pullback
@@ -233,10 +234,13 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
233234
234235 @testset " check non-inferrable primal still passes if pullback inferrable" begin
235236 function ChainRulesCore. frule ((_, Δx), :: typeof (f_inferrable_pullback_only), x)
236- return (x > 0 ? Float64 (x) : Float32 (x), x > 0 ? Float64 (Δx) : Float32 (Δx))
237+ T = x > 0 ? Float64 : Float32
238+ return T (x), T (unthunk (Δx))
237239 end
238240 function ChainRulesCore. rrule (:: typeof (f_inferrable_pullback_only), x)
239- f_inferrable_pullback_only_pullback (Δy) = (NoTangent (), oftype (x, Δy))
241+ function f_inferrable_pullback_only_pullback (Δy)
242+ return NoTangent (), oftype (x, unthunk (Δy))
243+ end
240244 return x > 0 ? Float64 (x) : Float32 (x), f_inferrable_pullback_only_pullback
241245 end
242246 test_frule (f_inferrable_pullback_only, 2.0 ; check_inferred= true )
@@ -441,7 +445,9 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
441445 ẋ = [4.0 , 5.0 , 6.0 ]
442446 xcopy, ẋcopy = copy (x), copy (ẋ)
443447 y = [1 , 2 ]
444- test_frule (finplace!, x ⊢ ẋ; fkwargs= (y= y,))
448+ # Don't test tangent transforms, we do not support thunks for mutating frules
449+ # TODO : Should we disable testing thunks for frules in general
450+ test_frule (finplace!, x ⊢ ẋ; fkwargs= (y= y,), tangent_transforms= [])
445451 @test x == xcopy
446452 @test ẋ == ẋcopy
447453 @test y == [1 , 2 ]
@@ -462,7 +468,8 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
462468 return s
463469 end
464470
465- function ChainRulesCore. frule ((_, Δiter), :: typeof (iterfun), iter)
471+ function ChainRulesCore. frule ((_, Δiter_), :: typeof (iterfun), iter)
472+ Δiter = unthunk (Δiter_)
466473 iter_Δiter = zip (iter, Δiter)
467474 state = iterate (iter_Δiter)
468475 state === nothing && error ()
0 commit comments