Skip to content

Commit d9fc3bf

Browse files
authored
Merge pull request #494 from MilesCranmer/fix-mlj-option-caching
fix: caching of options in MLJ regressors
2 parents 169f973 + 6dcaedb commit d9fc3bf

File tree

6 files changed

+88
-5
lines changed

6 files changed

+88
-5
lines changed

src/Core.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ using .OptionsStructModule:
3030
Options,
3131
ComplexityMapping,
3232
specialized_options,
33-
operator_specialization
33+
operator_specialization,
34+
WarmStartIncompatibleError,
35+
check_warm_start_compatibility
3436
using .OperatorsModule:
3537
get_safe_op,
3638
plus,

src/MLJInterface.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ using ..CoreModule:
3737
ComplexityMapping,
3838
AbstractExpressionSpec,
3939
ExpressionSpec,
40-
get_expression_type
40+
get_expression_type,
41+
check_warm_start_compatibility
4142
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
4243
using ..ComplexityModule: compute_complexity
4344
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
@@ -232,7 +233,10 @@ function MMI.update(
232233
y,
233234
w=nothing,
234235
)
235-
options = old_fitresult === nothing ? get_options(m) : old_fitresult.options
236+
options = get_options(m)
237+
if !isnothing(old_fitresult)
238+
check_warm_start_compatibility(old_fitresult.options, options)
239+
end
236240
return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing)
237241
end
238242
function _update(

src/Options.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ $(OPTION_DESCRIPTIONS)
10381038
}(
10391039
operators,
10401040
op_constraints,
1041+
_nested_constraints,
10411042
complexity_mapping,
10421043
tournament_selection_n,
10431044
tournament_selection_p,
@@ -1099,7 +1100,6 @@ $(OPTION_DESCRIPTIONS)
10991100
max_evals,
11001101
input_stream,
11011102
skip_mutation_failures,
1102-
_nested_constraints,
11031103
deterministic,
11041104
define_helper_functions,
11051105
use_recorder,

src/OptionsStruct.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ end
2828

2929
Base.eltype(::ComplexityMapping{T}) where {T} = T
3030

31+
function Base.:(==)(a::ComplexityMapping, b::ComplexityMapping)
32+
return a.use == b.use &&
33+
a.op_complexities == b.op_complexities &&
34+
a.variable_complexity == b.variable_complexity &&
35+
a.constant_complexity == b.constant_complexity
36+
end
37+
3138
"""Promote type when defining complexity mapping."""
3239
function ComplexityMapping(;
3340
op_complexities::Tuple{Vararg{Vector,D}},
@@ -182,6 +189,7 @@ struct Options{
182189
} <: AbstractOptions
183190
operators::OP
184191
op_constraints::OP_CONSTRAINTS
192+
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
185193
complexity_mapping::CM
186194
tournament_selection_n::Int
187195
tournament_selection_p::Float32
@@ -243,7 +251,6 @@ struct Options{
243251
max_evals::Union{Int,Nothing}
244252
input_stream::IO
245253
skip_mutation_failures::Bool
246-
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
247254
deterministic::Bool
248255
define_helper_functions::Bool
249256
use_recorder::Bool
@@ -290,4 +297,40 @@ end
290297
end
291298
end
292299

300+
struct WarmStartIncompatibleError <: Exception
301+
fields::Vector{Symbol}
302+
end
303+
304+
function Base.showerror(io::IO, e::WarmStartIncompatibleError)
305+
print(io, "Warm start incompatible due to changed field(s): ")
306+
join(io, e.fields, ", ")
307+
return print(io, ". Use `fit!(mach, force=true)` to restart training.")
308+
end
309+
310+
check_warm_start_compatibility(::AbstractOptions, ::AbstractOptions) = nothing # LCOV_EXCL_LINE
311+
312+
function check_warm_start_compatibility(old_options::Options, new_options::Options)
313+
incompatible_fields = (
314+
:operators,
315+
:op_constraints,
316+
:nested_constraints,
317+
:complexity_mapping,
318+
:dimensionless_constants_only,
319+
:maxsize,
320+
:maxdepth,
321+
:populations,
322+
:population_size,
323+
:node_type,
324+
:expression_type,
325+
:expression_options,
326+
)
327+
328+
changed = [
329+
f for f in incompatible_fields if
330+
getproperty(old_options, f) != getproperty(new_options, f)
331+
]
332+
isempty(changed) || throw(WarmStartIncompatibleError(changed))
333+
return nothing
334+
end
335+
293336
end

src/SymbolicRegression.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ using .CoreModule:
243243
AbstractOptions,
244244
Options,
245245
ComplexityMapping,
246+
WarmStartIncompatibleError,
246247
AbstractMutationWeights,
247248
MutationWeights,
248249
AbstractExpressionSpec,

test/test_mlj.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,36 @@ end
257257
end
258258
@test occursin("Evaluation failed either due to", msg)
259259
end
260+
261+
@testitem "MLJ options caching fix" tags = [:part3] begin
262+
using SymbolicRegression
263+
using SymbolicRegression: WarmStartIncompatibleError
264+
using MLJBase
265+
using Random: MersenneTwister
266+
using Suppressor
267+
268+
include("test_params.jl")
269+
270+
# Test that parameter changes are respected and incompatible changes throw errors
271+
rng = MersenneTwister(0)
272+
X = (x1=randn(rng, 50), x2=randn(rng, 50))
273+
y = @. 2.0 * X.x1 + 3.0 * X.x2
274+
275+
model = SRRegressor(;
276+
binary_operators=[+, -, *], niterations=2, tournament_selection_n=10, populations=2
277+
)
278+
279+
mach = machine(model, X, y)
280+
@suppress fit!(mach, verbosity=0)
281+
282+
# Test compatible parameter change
283+
model.tournament_selection_n = 20
284+
@suppress fit!(mach, verbosity=0)
285+
@test mach.fitresult.options.tournament_selection_n == 20 # Should be updated
286+
287+
# Test incompatible parameter change throws error with correct message
288+
model.populations = 4
289+
err = @test_throws WarmStartIncompatibleError @suppress fit!(mach, verbosity=0)
290+
@test :populations err.value.fields
291+
@test occursin("force=true", sprint(showerror, err.value))
292+
end

0 commit comments

Comments
 (0)