Skip to content

Commit 0781cf7

Browse files
committed
feat: handle mixed types in body of template expression
1 parent d05f559 commit 0781cf7

File tree

4 files changed

+89
-6
lines changed

4 files changed

+89
-6
lines changed

src/ComposableExpression.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,22 +260,45 @@ end
260260
# Basically we want to vectorize every single operation on ValidVector,
261261
# so that the user can use it easily.
262262

263+
function _apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
264+
vx = map(_get_value, x)
265+
safe_op = get_safe_op(op)
266+
result = safe_op.(vx...)
267+
return ValidVector(result, is_valid_array(result))
268+
end
269+
263270
function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
264271
if all(_is_valid, x)
265-
vx = map(_get_value, x)
266-
safe_op = get_safe_op(op)
267-
result = safe_op.(vx...)
268-
return ValidVector(result, is_valid_array(result))
272+
return _apply_operator(op, x...)
269273
else
270274
example_vector =
271275
something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector
272-
return ValidVector(_get_value(example_vector), false)
276+
expected_return_type = Base.promote_op(
277+
_apply_operator, typeof(op), map(typeof, x)...
278+
)
279+
if expected_return_type !== Union{} &&
280+
expected_return_type <: ValidVector{<:AbstractArray}
281+
return ValidVector(
282+
_match_eltype(expected_return_type, example_vector.x), false
283+
)::expected_return_type
284+
else
285+
return ValidVector(example_vector.x, false)
286+
end
273287
end
274288
end
275289
_is_valid(x::ValidVector) = x.valid
276290
_is_valid(x) = true
277291
_get_value(x::ValidVector) = x.x
278292
_get_value(x) = x
293+
function _match_eltype(
294+
::Type{<:ValidVector{<:AbstractArray{T1}}}, x::AbstractArray{T2}
295+
) where {T1,T2}
296+
if T1 == T2
297+
return x
298+
else
299+
return Base.Fix1(convert, T1).(x)
300+
end
301+
end
279302

280303
struct ValidVectorMixError <: Exception end
281304
struct ValidVectorAccessError <: Exception end

src/TemplateExpression.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,18 @@ and this would automatically handle the validity and vectorization.
665665
)
666666
end
667667

668+
function _match_input_eltype(
669+
::Type{<:AbstractMatrix{T1}}, result::AbstractVector{T2}
670+
) where {T1,T2}
671+
if T1 != T2 && T1 <: AbstractFloat && T2 <: AbstractFloat
672+
# Just to handle cases where the user might write
673+
# 0.5 in their template spec, but the data is Float32.
674+
return Base.Fix1(convert, T1).(result)
675+
else
676+
return result
677+
end
678+
end
679+
668680
@stable(
669681
default_mode = "disable",
670682
default_union_limit = 2,
@@ -695,7 +707,7 @@ end
695707
if !(result isa ValidVector)
696708
throw(TemplateReturnError())
697709
end
698-
return result.x, result.valid
710+
return _match_input_eltype(typeof(cX), result.x), result.valid
699711
end
700712
function (ex::TemplateExpression)(
701713
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...

test/test_composable_expression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,23 @@ end
182182
@test ex(ValidVector([1.0, 1.0], true), 2.0).x [3.0, 3.0]
183183
@test ex(ValidVector([1.0, 1.0], false), 2.0).valid == false
184184
end
185+
186+
@testitem "ValidVector operations with Union{} return type" tags = [:part2] begin
187+
using SymbolicRegression: ValidVector
188+
using SymbolicRegression.ComposableExpressionModule: apply_operator
189+
190+
error_op(::Any, ::Any) = error("This should cause Union{} inference")
191+
192+
x = ValidVector([1.0, 2.0], false)
193+
y = ValidVector([3.0, 4.0], false)
194+
195+
result = apply_operator(error_op, x, y)
196+
@test result isa ValidVector
197+
@test !result.valid
198+
@test result.x == [1.0, 2.0]
199+
200+
a = ValidVector(Float32[1.0, 2.0], false)
201+
b = 1.0
202+
result2 = apply_operator(*, a, b)
203+
@test result2 isa ValidVector{<:AbstractArray{Float64}}
204+
end

test/test_template_expression.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,31 @@ end
692692
@test contains(msg, "ValidVector is required")
693693
@test contains(msg, "ValidVector(my_data, computation_is_valid)")
694694
end
695+
696+
@testitem "Test Float32/Float64 type conversion in TemplateExpression" tags = [:part2] begin
697+
using SymbolicRegression
698+
using SymbolicRegression: eval_loss
699+
700+
template = @template_spec(expressions = (f,)) do x1, x2
701+
0.5 * f(x1, x2) # 0.5 is Float64 literal
702+
end
703+
704+
options = Options(; binary_operators=[+, *, /, -], expression_spec=template)
705+
x1 = ComposableExpression(Node{Float32}(; feature=1); operators=options.operators)
706+
x2 = ComposableExpression(Node{Float32}(; feature=2); operators=options.operators)
707+
f_expr = x1 + x2
708+
709+
template_expr = TemplateExpression(
710+
(; f=f_expr); structure=template.structure, operators=options.operators
711+
)
712+
713+
X = Float32[1.0 2.0; 3.0 4.0]
714+
result = template_expr(X)
715+
@test result isa Vector{Float32}
716+
717+
y = Float32[2.0, 3.0]
718+
dataset = Dataset(X, y)
719+
loss = eval_loss(template_expr, dataset, options)
720+
@test loss isa Float32
721+
@test loss 0.0
722+
end

0 commit comments

Comments
 (0)