285285# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484
286286module IsolatedModuleForTestingScoping
287287 # check that rules can be defined by macros without any additional imports
288- using ChainRulesCore: @scalar_rule , @non_differentiable
288+ using ChainRulesCore: @scalar_rule , @non_differentiable , @opt_out
289289
290290 # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved
291291 const ChainRulesCore = nothing
@@ -303,11 +303,20 @@ module IsolatedModuleForTestingScoping
303303 my_id (x) = x
304304 @scalar_rule (my_id (x), 1.0 )
305305
306+ # @opt_out
307+ first_oa (x, y) = x
308+ @scalar_rule (first_oa (x, y), (1 , 0 ))
309+ # Declared without using the ChainRulesCore namespace qualification
310+ # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/545
311+ @opt_out rrule (:: typeof (first_oa), x:: T , y:: T ) where {T<: Float16 }
312+ @opt_out frule (:: Any , :: typeof (first_oa), x:: T , y:: T ) where {T<: Float16 }
313+
306314 module IsolatedSubmodule
307315 # check that rules defined in isolated module without imports can be called
308316 # without errors
309317 using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
310- using .. IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
318+ using ChainRulesCore: no_rrule, no_frule
319+ using .. IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id, first_oa
311320 using Test
312321
313322 @testset " @non_differentiable" begin
@@ -339,6 +348,25 @@ module IsolatedModuleForTestingScoping
339348
340349 @test derivatives_given_output (y, my_id, x) == ((1.0 ,),)
341350 end
351+
352+ @testset " @optout" begin
353+ # rrule
354+ @test rrule (first_oa, Float16 (3.0 ), Float16 (4.0 )) === nothing
355+ @test ! isempty (
356+ Iterators. filter (methods (no_rrule)) do m
357+ m. sig <: Tuple{Any,typeof(first_oa),T,T} where {T<: Float16 }
358+ end ,
359+ )
360+
361+ # frule
362+ @test frule ((NoTangent (), 1 , 0 ), first_oa, Float16 (3.0 ), Float16 (4.0 )) ===
363+ nothing
364+ @test ! isempty (
365+ Iterators. filter (methods (no_frule)) do m
366+ m. sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<: Float16 }
367+ end ,
368+ )
369+ end
342370 end
343371end
344372# ! format: on
0 commit comments