Skip to content

Commit e2d7287

Browse files
committed
test: assert new error messages in ComposableExpression
1 parent 6f41fc4 commit e2d7287

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

test/test_composable_expression.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,43 @@ end
130130
@test ex(x1_val, x2_val).valid == false
131131
end
132132

133+
@testitem "ValidVector helpful error messages" tags = [:part2] begin
134+
using SymbolicRegression
135+
using SymbolicRegression: ValidVector, ValidVectorMixError, ValidVectorAccessError
136+
137+
vv = ValidVector([1.0, 2.0], true)
138+
v = [3.0, 4.0]
139+
140+
# Helper function to get error message
141+
get_error_msg(err) =
142+
let io = IOBuffer()
143+
Base.showerror(io, err)
144+
String(take!(io))
145+
end
146+
147+
# Test vector arithmetic errors encourage ValidVector wrapping
148+
err_mix = @test_throws ValidVectorMixError vv + v
149+
@test_throws ValidVectorMixError v * vv # Test other direction too
150+
151+
mix_msg = get_error_msg(err_mix.value)
152+
@test contains(
153+
mix_msg,
154+
"ValidVector handles validity checks, auto-vectorization, and batching in template expressions",
155+
)
156+
157+
# Test array access errors mention .x and .valid
158+
err_access = @test_throws ValidVectorAccessError vv[1]
159+
@test_throws ValidVectorAccessError length(vv)
160+
@test_throws ValidVectorAccessError push!(vv, 5.0)
161+
162+
access_msg = get_error_msg(err_access.value)
163+
@test contains(access_msg, "valid_ar.x[1]")
164+
@test contains(access_msg, "valid_ar.valid")
165+
@test contains(access_msg, "length(valid_ar.x)")
166+
@test contains(access_msg, "doesn't support direct array operations")
167+
@test contains(access_msg, "ValidVector handles validity/batching automatically")
168+
end
169+
133170
@testitem "Test Number inputs" tags = [:part2] begin
134171
using SymbolicRegression: ComposableExpression, Node, ValidVector
135172
using DynamicExpressions: OperatorEnum

test/test_template_expression.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,38 @@ end
657657
parse_guesses(PopMember{Float64,Float64}, [bad_guess], [dataset], options)
658658
)
659659
end
660+
661+
@testitem "Template expression return validation" tags = [:part2] begin
662+
using SymbolicRegression:
663+
TemplateReturnError,
664+
ValidVector,
665+
ComposableExpression,
666+
TemplateStructure,
667+
TemplateExpression
668+
using DynamicExpressions: OperatorEnum, Node
669+
670+
operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
671+
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names=nothing)
672+
673+
# Test that returning a regular vector from template expression throws TemplateReturnError
674+
bad_structure = TemplateStructure{(:f,)}(
675+
((; f), (x,)) -> [1.0, 2.0]; # Returns regular Vector instead of ValidVector
676+
num_features=(; f=1),
677+
)
678+
bad_expr = TemplateExpression(
679+
(; f=x1); structure=bad_structure, operators, variable_names=nothing
680+
)
681+
X = [1.0 2.0]'
682+
683+
function get_error_msg(err)
684+
io = IOBuffer()
685+
Base.showerror(io, err)
686+
return String(take!(io))
687+
end
688+
689+
err = @test_throws TemplateReturnError bad_expr(X)
690+
msg = get_error_msg(err.value)
691+
@test contains(msg, "Template expression returned a regular Vector")
692+
@test contains(msg, "ValidVector is required")
693+
@test contains(msg, "ValidVector(my_data, computation_is_valid)")
694+
end

0 commit comments

Comments
 (0)