diff --git a/test/test_utils.jl b/test/test_utils.jl index 945c439ce..b2172a20a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -104,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) @inferred f(args...) @inferred Zygote._pullback(ctx, f, args...) out, pb = Zygote._pullback(ctx, f, args...) - @inferred pb(out) + @inferred collect(pb(out)) end function test_ADs(