@@ -598,6 +598,35 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
598598 test_rrule (rev_trouble, (3 , 3.0 ) ⊢ Tangent {Tuple{Int,Float64}} (ZeroTangent (), 1.0 ))
599599 end
600600
601+ @testset " check_thunked_output_tangent" begin
602+ @testset " no method for thunk" begin
603+ does_not_accept_thunk_id (x) = x
604+ function ChainRulesCore. rrule (:: typeof (does_not_accept_thunk_id), x)
605+ does_not_accept_thunk_id_pullback (ȳ:: AbstractArray ) = (NoTangent () ,ȳ)
606+ return does_not_accept_thunk_id (x), does_not_accept_thunk_id_pullback
607+ end
608+
609+ test_rrule (
610+ does_not_accept_thunk_id, [1.0 , 2.0 ]; check_thunked_output_tangent= false
611+ )
612+ @test errors (r" MethodError.*Thunk" ) do
613+ test_rrule (does_not_accept_thunk_id, [1.0 , 2.0 ])
614+ end
615+ end
616+
617+ @testset " Thunk wrong" begin
618+ bad_thunk_id (x) = x
619+ function ChainRulesCore. rrule (:: typeof (bad_thunk_id), x)
620+ bad_thunk_id_pullback (ȳ:: AbstractArray ) = (NoTangent (), ȳ)
621+ bad_thunk_id_pullback (ȳ:: AbstractThunk ) = (NoTangent (), 2 * ȳ)
622+ return bad_thunk_id (x), bad_thunk_id_pullback
623+ end
624+
625+ test_rrule (bad_thunk_id, [1.0 , 2.0 ]; check_thunked_output_tangent= false )
626+ @test fails (()-> test_rrule (bad_thunk_id, [1.0 , 2.0 ]))
627+ end
628+ end
629+
601630 @testset " error message about incorrectly using ZeroTangent()" begin
602631 foo (a, i) = a[i]
603632 function ChainRulesCore. rrule (:: typeof (foo), a, i)
0 commit comments