Skip to content

Commit 2c225aa

Browse files
Fix CI: ChainRules test + JET 0.11 matcher (#153)
1 parent d4b12df commit 2c225aa

File tree

9 files changed

+124
-39
lines changed

9 files changed

+124
-39
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Reexport = "1"
4242
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
4343
Zygote = "0.7"
4444
julia = "1.10"
45+
Random = "1"
46+
TOML = "1"
4547

4648
[extras]
4749
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"

src/ExpressionAlgebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function Base.showerror(io::IO, e::MissingOperatorError)
3535
end
3636

3737
"""
38-
declare_operator_alias(op::Function, ::Val{arity})::Function
38+
declare_operator_alias(op, ::Val{arity})
3939
4040
Define how an internal operator should be matched against user-provided operators in expression trees.
4141
@@ -52,7 +52,7 @@ Which would allow a user to write `sqrt(x::Expression)`
5252
and have it match the operator `safe_sqrt` stored in the binary operators
5353
of the expression.
5454
"""
55-
declare_operator_alias(op::F, _) where {F<:Function} = op
55+
declare_operator_alias(op::F, _) where {F} = op
5656

5757
allow_chaining(@nospecialize(op)) = false
5858
allow_chaining(::typeof(+)) = true

src/ValueInterface.jl

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ Base.@propagate_inbounds function pack_scalar_constants!(
2424
return idx + 1
2525
end
2626

27+
# Fallback so callers (and static analysis tools like JET) don't see a MethodError.
28+
# Types that want to participate in the ValueInterface must implement a more specific method.
29+
function pack_scalar_constants!(::AbstractVector{<:Number}, ::Int64, value)
30+
throw(ArgumentError("pack_scalar_constants! not implemented for $(typeof(value))"))
31+
end
32+
2733
"""
2834
unpack_scalar_constants(nvals, idx, value)
2935
@@ -37,10 +43,14 @@ Returns a tuple of the next index to read from, and the filled-in value.
3743
"""
3844
Base.@propagate_inbounds function unpack_scalar_constants(
3945
nvals::AbstractVector{<:Number}, idx::Int64, value::T
40-
) where {T}
46+
) where {T<:Number}
4147
return (idx + 1, convert(T, nvals[idx]))
4248
end
4349

50+
function unpack_scalar_constants(::AbstractVector{<:Number}, ::Int64, value)
51+
throw(ArgumentError("unpack_scalar_constants not implemented for $(typeof(value))"))
52+
end
53+
4454
"""
4555
count_scalar_constants(value)
4656
@@ -60,24 +70,57 @@ end
6070
function _check_is_valid_array(x)
6171
return is_valid_array([x]) isa Bool && is_valid_array([x]) == is_valid(x)
6272
end
63-
function _check_get_number_type(x)
73+
function _check_get_number_type(x::X) where {X}
6474
try
65-
get_number_type(typeof(x)) <: Number
75+
get_number_type(X) <: Number
6676
catch e
6777
@error e
6878
return false
6979
end
7080
end
71-
function _check_pack_scalar_constants!(x)
72-
packed_x = Vector{get_number_type(typeof(x))}(undef, count_scalar_constants(x))
81+
function _check_pack_scalar_constants!(x::X) where {X}
82+
if !applicable(count_scalar_constants, x)
83+
return false
84+
end
85+
n = count_scalar_constants(x)
86+
87+
packed_x = if X <: Number
88+
Vector{X}(undef, n)
89+
else
90+
# For non-`Number` values, we can't assume a concrete scalar type here.
91+
# Use a generic numeric buffer; correctness is checked by roundtripping.
92+
Vector{Float64}(undef, n)
93+
end
94+
95+
if !applicable(pack_scalar_constants!, packed_x, 1, x)
96+
return false
97+
end
98+
7399
new_idx = pack_scalar_constants!(packed_x, 1, x)
74-
return new_idx == 1 + count_scalar_constants(x)
100+
return new_idx == 1 + n
75101
end
76-
function _check_unpack_scalar_constants(x)
77-
packed_x = Vector{get_number_type(typeof(x))}(undef, count_scalar_constants(x))
102+
function _check_unpack_scalar_constants(x::X) where {X}
103+
if !applicable(count_scalar_constants, x)
104+
return false
105+
end
106+
n = count_scalar_constants(x)
107+
108+
packed_x = if X <: Number
109+
Vector{X}(undef, n)
110+
else
111+
Vector{Float64}(undef, n)
112+
end
113+
114+
if !applicable(pack_scalar_constants!, packed_x, 1, x)
115+
return false
116+
end
117+
if !applicable(unpack_scalar_constants, packed_x, 1, x)
118+
return false
119+
end
120+
78121
pack_scalar_constants!(packed_x, 1, x)
79122
new_idx, x2 = unpack_scalar_constants(packed_x, 1, x)
80-
return new_idx == 1 + count_scalar_constants(x) && x2 == x
123+
return new_idx == 1 + n && x2 == x
81124
end
82125
function _check_count_scalar_constants(x)
83126
return count_scalar_constants(x) isa Int &&

src/precompile.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types
9090
end
9191

9292
function test_functions_on_trees(::Type{T}, operators) where {T}
93-
local x, c, tree
93+
local x, c
94+
tree = Node(T; feature=1)
9495
num_unaops = length(operators.unaops)
9596
num_binops = length(operators.binops)
9697
@assert num_unaops > 0 && num_binops > 0

test/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
66
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
77
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9-
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
109
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
10+
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1313
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
@@ -25,7 +25,8 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2626

2727
[compat]
28-
Aqua = "0.7"
28+
Aqua = "0.8"
29+
JET = "0.9, 0.10"
2930

3031
[preferences.DynamicExpressions]
3132
dispatch_doctor_codegen_level = "min"

test/runtests.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,42 @@ if "jet" in test_name
2323
)
2424
using JET
2525
using DynamicExpressions
26-
struct MyIgnoredModule
27-
mod::Module
28-
end
29-
function JET.match_module(
30-
mod::MyIgnoredModule, @nospecialize(report::JET.InferenceErrorReport)
31-
)
32-
s_mod = string(mod.mod)
33-
any(report.vst) do vst
34-
occursin(s_mod, string(JET.linfomod(vst.linfo)))
26+
27+
ignored_mod = DynamicExpressions.NonDifferentiableDeclarationsModule
28+
29+
if isdefined(JET, :match_report) &&
30+
isdefined(JET, :ReportMatcher) &&
31+
isdefined(JET, :AnyFrameModule)
32+
# JET >= 0.11
33+
JET.test_package(
34+
DynamicExpressions;
35+
target_defined_modules=true,
36+
ignored_modules=(JET.AnyFrameModule(ignored_mod),),
37+
)
38+
else
39+
# JET <= 0.10: old matcher API
40+
struct MyIgnoredModule
41+
mod::Module
42+
end
43+
function JET.match_module(
44+
mod::MyIgnoredModule, @nospecialize(report::JET.InferenceErrorReport)
45+
)
46+
s_mod = string(mod.mod)
47+
any(report.vst) do vst
48+
occursin(s_mod, string(JET.linfomod(vst.linfo)))
49+
end
3550
end
51+
# On JET 0.10, `target_defined_modules` is not available and also
52+
# can cause spurious possible-error reports when analyzing beyond
53+
# the package's own modules. Restrict to the DynamicExpressions module.
54+
JET.test_package(
55+
DynamicExpressions;
56+
target_modules=(DynamicExpressions,),
57+
ignored_modules=(MyIgnoredModule(ignored_mod),),
58+
)
59+
# TODO: Hack to get JET to ignore modules
60+
# https://github.com/aviatesk/JET.jl/issues/570#issuecomment-2199167755
3661
end
37-
JET.test_package(
38-
DynamicExpressions;
39-
target_defined_modules=true,
40-
ignored_modules=(
41-
MyIgnoredModule(DynamicExpressions.NonDifferentiableDeclarationsModule),
42-
),
43-
)
44-
# TODO: Hack to get JET to ignore modules
45-
# https://github.com/aviatesk/JET.jl/issues/570#issuecomment-2199167755
4662
end
4763
end
4864
if "main" in test_name

test/test_aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
using DynamicExpressions
22
using Aqua
33

4-
Aqua.test_all(DynamicExpressions; project_toml_formatting=false)
4+
Aqua.test_all(DynamicExpressions)

test/test_chainrules.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ let
8383
end
8484

8585
# Operator that is NaN for forward pass
86-
bad_op(x) = x > 0.0 ? log(x) : convert(typeof(x), NaN)
86+
# Define only for numeric types; `@extend_operators` adds the `Node` method.
87+
bad_op(x::Real) = x > 0.0 ? log(x) : convert(typeof(x), NaN)
8788
# And operator that is undefined for backward pass
88-
undefined_grad_op(x) = x >= 0.0 ? x : zero(x)
89+
undefined_grad_op(x::Real) = x >= 0.0 ? x : zero(x)
8990
# And operator that gives a NaN for backward pass
9091
bad_grad_op(x) = x
9192

@@ -102,9 +103,11 @@ let
102103
@extend_operators operators
103104
x1 = Node(Float64; feature=1)
104105

105-
nan_forward = bad_op(x1 + 0.5)
106-
undefined_grad = undefined_grad_op(x1 + 0.5)
107-
nan_grad = bad_grad_op(x1)
106+
# `@extend_operators` defines methods via `eval`, which can trigger world-age issues
107+
# when tests are executed inside a function (Julia >= 1.12). Use `invokelatest`.
108+
nan_forward = Base.invokelatest(bad_op, x1 + 0.5)
109+
undefined_grad = Base.invokelatest(undefined_grad_op, x1 + 0.5)
110+
nan_grad = Base.invokelatest(bad_grad_op, x1)
108111

109112
function eval_tree(X, tree)
110113
y, _ = eval_tree_array(tree, X, operators)

test/test_value_interface.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ end
77

88
@testitem "ValueInterface generic helpers" begin
99
using DynamicExpressions.ValueInterfaceModule:
10-
is_valid, count_scalar_constants, get_number_type
10+
is_valid,
11+
count_scalar_constants,
12+
get_number_type,
13+
pack_scalar_constants!,
14+
unpack_scalar_constants
1115
using Test
1216

1317
# generic `is_valid` (non-number) falls back to `true`
@@ -20,4 +24,19 @@ end
2024
# simple scalar utilities
2125
@test count_scalar_constants(42) == 1
2226
@test get_number_type(Float32) == Float32
27+
28+
# error paths: fallback methods should throw a helpful error
29+
@test_throws ErrorException get_number_type(String)
30+
@test_throws ArgumentError pack_scalar_constants!(Float64[], 1, "hello")
31+
@test_throws ArgumentError unpack_scalar_constants(Float64[1.0], 1, "hello")
32+
33+
# internal interface self-checks: verify we handle bad types gracefully
34+
@test_logs (:error,) begin
35+
@test DynamicExpressions.ValueInterfaceModule._check_get_number_type("hello") ==
36+
false
37+
end
38+
@test DynamicExpressions.ValueInterfaceModule._check_pack_scalar_constants!("hello") ==
39+
false
40+
@test DynamicExpressions.ValueInterfaceModule._check_unpack_scalar_constants("hello") ==
41+
false
2342
end

0 commit comments

Comments
 (0)