Skip to content

Commit d05f559

Browse files
authored
Merge pull request #484 from MilesCranmer/autoconvert_composable
feat: add Number input support to ComposableExpression
2 parents b87426d + a7b9123 commit d05f559

File tree

5 files changed

+218
-8
lines changed

5 files changed

+218
-8
lines changed

src/ComposableExpression.jl

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,42 @@ function (ex::AbstractComposableExpression)(
187187
return x .* nan
188188
end
189189
end
190+
# Method for all-Number arguments (scalars)
191+
function (ex::AbstractComposableExpression)(x::Number, _xs::Vararg{Number,N}) where {N}
192+
xs = (x, _xs...)
193+
194+
vectors = ntuple(i -> ValidVector([float(xs[i])], true), length(xs))
195+
return only(_get_value(ex(vectors...)))
196+
end
197+
190198
function (ex::AbstractComposableExpression)(
191-
x::ValidVector, _xs::Vararg{ValidVector,N}
199+
x::Union{ValidVector,Number}, _xs::Vararg{Union{ValidVector,Number},N}
192200
) where {N}
193201
xs = (x, _xs...)
194-
valid = all(_is_valid, xs)
195-
if !valid
196-
return ValidVector(_get_value(first(xs)), false)
197-
else
198-
X = Matrix(stack(map(_get_value, xs))')
202+
sample_vector =
203+
let first_valid_vector_idx = findfirst(arg -> arg isa ValidVector, xs)::Int
204+
xs[first_valid_vector_idx]::ValidVector
205+
end
206+
207+
# Convert Numbers to ValidVectors based on first ValidVector's size
208+
valid_args = ntuple(length(xs)) do i
209+
arg = xs[i]
210+
if arg isa ValidVector
211+
arg
212+
else
213+
# Convert Number to ValidVector with repeated values
214+
filled_array = similar(sample_vector.x)
215+
fill!(filled_array, arg)
216+
ValidVector(filled_array, true)
217+
end
218+
end
219+
220+
if all(_is_valid, valid_args)
221+
X = stack(map(_get_value, valid_args); dims=1)
199222
eval_options = get_eval_options(ex)
200223
return ValidVector(eval_tree_array(ex, X; eval_options))
224+
else
225+
return ValidVector(_get_value(first(valid_args)), false)
201226
end
202227
end
203228
function (ex::AbstractComposableExpression{T})() where {T}
@@ -252,6 +277,55 @@ _is_valid(x) = true
252277
_get_value(x::ValidVector) = x.x
253278
_get_value(x) = x
254279

280+
struct ValidVectorMixError <: Exception end
281+
struct ValidVectorAccessError <: Exception end
282+
283+
function Base.showerror(io::IO, ::ValidVectorMixError)
284+
return print(
285+
io,
286+
"""
287+
ValidVectorMixError: Cannot mix ValidVector with regular Vector.
288+
289+
ValidVector handles validity checks, auto-vectorization, and batching in template expressions.
290+
The .valid field tracks whether any upstream computation failed (false = failed, true = valid).
291+
292+
Wrap your vectors in ValidVector:
293+
294+
```julia
295+
valid_ar1 = ValidVector(ar1, all(isfinite, ar1))
296+
valid_ar1 + valid_ar2
297+
```
298+
299+
Alternatively, you can access the vector from a ValidVector with `my_validvector.x`,
300+
but you must be sure to propagate the `.valid` field. For example:
301+
302+
```julia
303+
out = ar1 .+ valid_ar2.x
304+
ValidVector(out, all(isfinite, out) && valid_ar2.valid)
305+
```
306+
307+
""",
308+
)
309+
end
310+
311+
function Base.showerror(io::IO, ::ValidVectorAccessError)
312+
return print(
313+
io,
314+
"""
315+
ValidVectorAccessError: ValidVector doesn't support direct array operations.
316+
317+
Use .x for data and .valid for validity:
318+
319+
```julia
320+
valid_ar.x[1] # indexing
321+
length(valid_ar.x) # length
322+
valid_ar.valid # check validity (false = any upstream computation failed)
323+
```
324+
325+
ValidVector handles validity/batching automatically in template expressions.""",
326+
)
327+
end
328+
255329
#! format: off
256330
# First, binary operators:
257331
for op in (
@@ -264,6 +338,9 @@ for op in (
264338
Base.$(op)(x::ValidVector, y::ValidVector) = apply_operator(Base.$(op), x, y)
265339
Base.$(op)(x::ValidVector, y::Number) = apply_operator(Base.$(op), x, y)
266340
Base.$(op)(x::Number, y::ValidVector) = apply_operator(Base.$(op), x, y)
341+
342+
Base.$(op)(::ValidVector, ::AbstractVector) = throw(ValidVectorMixError())
343+
Base.$(op)(::AbstractVector, ::ValidVector) = throw(ValidVectorMixError())
267344
end
268345
end
269346
function Base.literal_pow(::typeof(^), x::ValidVector, ::Val{p}) where {p}
@@ -286,6 +363,12 @@ for op in (
286363
end
287364
#! format: on
288365

366+
Base.length(::ValidVector) = throw(ValidVectorAccessError())
367+
Base.push!(::ValidVector, ::Any) = throw(ValidVectorAccessError())
368+
for op in (:getindex, :size, :append!, :setindex!)
369+
@eval Base.$(op)(::ValidVector, ::Any...) = throw(ValidVectorAccessError())
370+
end
371+
289372
# TODO: Support for 3-ary operators
290373

291374
end

src/SymbolicRegression.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ using .SearchUtilsModule:
341341
using .LoggingModule: AbstractSRLogger, SRLogger, get_logger
342342
using .TemplateExpressionModule:
343343
TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params
344-
using .TemplateExpressionModule: ValidVector
345-
using .ComposableExpressionModule: ComposableExpression
344+
using .TemplateExpressionModule: ValidVector, TemplateReturnError
345+
using .ComposableExpressionModule:
346+
ComposableExpression, ValidVectorMixError, ValidVectorAccessError
346347
using .ExpressionBuilderModule: embed_metadata, strip_metadata
347348
using .ParametricExpressionModule: ParametricExpressionSpec
348349
using .TemplateExpressionMacroModule: @template_spec

src/TemplateExpression.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,40 @@ function HOF.make_prefix(::TemplateExpression, ::AbstractOptions, ::Dataset)
631631
return ""
632632
end
633633

634+
struct TemplateReturnError <: Exception end
635+
636+
function Base.showerror(io::IO, ::TemplateReturnError)
637+
return print(
638+
io,
639+
"""
640+
TemplateReturnError: Template expression returned a regular Vector, but ValidVector is required.
641+
642+
Template expressions must return ValidVector for proper handling:
643+
644+
```julia
645+
return ValidVector(my_data, computation_is_valid)
646+
```
647+
648+
The .valid field is used to track whether any upstream computation failed.
649+
It's important to handle this correctly.
650+
651+
Example of manually propagating validity:
652+
653+
```julia
654+
_f_result = f(x1, x2) # Returns ValidVector
655+
_g_result = g(x3) # Returns ValidVector
656+
657+
# Combine results manually and propagate validity
658+
combined_data = _f_result.x .+ _g_result.x
659+
return ValidVector(combined_data, _f_result.valid && _g_result.valid)
660+
```
661+
662+
Note that normally we could simply write `_f_result + _g_result`,
663+
and this would automatically handle the validity and vectorization.
664+
""",
665+
)
666+
end
667+
634668
@stable(
635669
default_mode = "disable",
636670
default_union_limit = 2,
@@ -657,6 +691,10 @@ end
657691
extra_args...,
658692
map(x -> ValidVector(copy(x), true), eachrow(cX)),
659693
)
694+
# Validate that template expression returned a ValidVector
695+
if !(result isa ValidVector)
696+
throw(TemplateReturnError())
697+
end
660698
return result.x, result.valid
661699
end
662700
function (ex::TemplateExpression)(

test/test_composable_expression.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,56 @@ end
129129
x2_val = ValidVector([1.0, 2.0], false)
130130
@test ex(x1_val, x2_val).valid == false
131131
end
132+
133+
@testitem "ValidVector helpful error messages" tags = [:part2] begin
134+
using SymbolicRegression
135+
using SymbolicRegression: ValidVector, ValidVectorMixError, ValidVectorAccessError
136+
137+
vv = ValidVector([1.0, 2.0], true)
138+
v = [3.0, 4.0]
139+
140+
# Helper function to get error message
141+
get_error_msg(err) =
142+
let io = IOBuffer()
143+
Base.showerror(io, err)
144+
String(take!(io))
145+
end
146+
147+
# Test vector arithmetic errors encourage ValidVector wrapping
148+
err_mix = @test_throws ValidVectorMixError vv + v
149+
@test_throws ValidVectorMixError v * vv # Test other direction too
150+
151+
mix_msg = get_error_msg(err_mix.value)
152+
@test contains(
153+
mix_msg,
154+
"ValidVector handles validity checks, auto-vectorization, and batching in template expressions",
155+
)
156+
157+
# Test array access errors mention .x and .valid
158+
err_access = @test_throws ValidVectorAccessError vv[1]
159+
@test_throws ValidVectorAccessError length(vv)
160+
@test_throws ValidVectorAccessError push!(vv, 5.0)
161+
162+
access_msg = get_error_msg(err_access.value)
163+
@test contains(access_msg, "valid_ar.x[1]")
164+
@test contains(access_msg, "valid_ar.valid")
165+
@test contains(access_msg, "length(valid_ar.x)")
166+
@test contains(access_msg, "doesn't support direct array operations")
167+
@test contains(access_msg, "ValidVector handles validity/batching automatically")
168+
end
169+
170+
@testitem "Test Number inputs" tags = [:part2] begin
171+
using SymbolicRegression: ComposableExpression, Node, ValidVector
172+
using DynamicExpressions: OperatorEnum
173+
174+
operators = OperatorEnum(; binary_operators=(+, *))
175+
x1 = ComposableExpression(Node{Float64}(; feature=1); operators)
176+
x2 = ComposableExpression(Node{Float64}(; feature=2); operators)
177+
ex = x1 + x2
178+
179+
@test ex(2.0, 3.0) ≈ 5.0
180+
@test isnan(ex(NaN, 3.0))
181+
@test ex(ValidVector([1.0], true), 2.0).x ≈ [3.0]
182+
@test ex(ValidVector([1.0, 1.0], true), 2.0).x ≈ [3.0, 3.0]
183+
@test ex(ValidVector([1.0, 1.0], false), 2.0).valid == false
184+
end

test/test_template_expression.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,38 @@ end
657657
parse_guesses(PopMember{Float64,Float64}, [bad_guess], [dataset], options)
658658
)
659659
end
660+
661+
@testitem "Template expression return validation" tags = [:part2] begin
662+
using SymbolicRegression:
663+
TemplateReturnError,
664+
ValidVector,
665+
ComposableExpression,
666+
TemplateStructure,
667+
TemplateExpression
668+
using DynamicExpressions: OperatorEnum, Node
669+
670+
operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
671+
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names=nothing)
672+
673+
# Test that returning a regular vector from template expression throws TemplateReturnError
674+
bad_structure = TemplateStructure{(:f,)}(
675+
((; f), (x,)) -> [1.0, 2.0]; # Returns regular Vector instead of ValidVector
676+
num_features=(; f=1),
677+
)
678+
bad_expr = TemplateExpression(
679+
(; f=x1); structure=bad_structure, operators, variable_names=nothing
680+
)
681+
X = [1.0 2.0]'
682+
683+
function get_error_msg(err)
684+
io = IOBuffer()
685+
Base.showerror(io, err)
686+
return String(take!(io))
687+
end
688+
689+
err = @test_throws TemplateReturnError bad_expr(X)
690+
msg = get_error_msg(err.value)
691+
@test contains(msg, "Template expression returned a regular Vector")
692+
@test contains(msg, "ValidVector is required")
693+
@test contains(msg, "ValidVector(my_data, computation_is_valid)")
694+
end

0 commit comments

Comments
 (0)