|
5 | 5 | using DispatchDoctor: @unstable |
6 | 6 |
|
7 | 7 | import SymbolicRegression.PopMemberModule: create_child |
| 8 | + import SymbolicRegression: strip_metadata |
8 | 9 |
|
9 | 10 | # Define a custom PopMember that tracks generation count |
10 | | - mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} |
| 11 | + mutable struct CustomPopMember{T,L,N<:AbstractExpression{T}} <: |
| 12 | + SymbolicRegression.AbstractPopMember{T,L,N} |
11 | 13 | tree::N |
12 | 14 | cost::L |
13 | 15 | loss::L |
|
153 | 155 | ) |
154 | 156 | end |
155 | 157 |
|
| 158 | + function strip_metadata( |
| 159 | + member::CustomPopMember, |
| 160 | + options::SymbolicRegression.AbstractOptions, |
| 161 | + dataset::SymbolicRegression.Dataset{T,L}, |
| 162 | + ) where {T,L} |
| 163 | + complexity = SymbolicRegression.compute_complexity(member.tree, options) |
| 164 | + return CustomPopMember( |
| 165 | + strip_metadata(member.tree, options, dataset), |
| 166 | + member.cost, |
| 167 | + member.loss, |
| 168 | + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), |
| 169 | + complexity, |
| 170 | + member.ref, |
| 171 | + member.parent, |
| 172 | + member.generation, |
| 173 | + ) |
| 174 | + end |
| 175 | + |
156 | 176 | # Test that we can run equation_search with CustomPopMember |
157 | 177 | X = randn(Float32, 2, 100) |
158 | 178 | y = @. X[1, :]^2 - X[2, :] |
|
191 | 211 | @test !isnothing(best_idx) |
192 | 212 | best_member = hall_of_fame.members[best_idx] |
193 | 213 | @test best_member isa CustomPopMember |
| 214 | + |
| 215 | + # Test that guesses API returns CustomPopMember instances |
| 216 | + guess_X = randn(Float32, 2, 80) |
| 217 | + guess_y = @. guess_X[1, :] - guess_X[2, :] |
| 218 | + guess_dataset = SymbolicRegression.Dataset(guess_X, guess_y) |
| 219 | + |
| 220 | + guess_options = SymbolicRegression.Options(; |
| 221 | + binary_operators=[+, -], |
| 222 | + populations=1, |
| 223 | + population_size=5, |
| 224 | + tournament_selection_n=2, |
| 225 | + maxsize=4, |
| 226 | + popmember_type=CustomPopMember, |
| 227 | + deterministic=true, |
| 228 | + seed=1, |
| 229 | + verbosity=0, |
| 230 | + progress=false, |
| 231 | + ) |
| 232 | + |
| 233 | + parsed = SymbolicRegression.parse_guesses( |
| 234 | + CustomPopMember{Float32,Float32}, ["x1 - x2"], [guess_dataset], guess_options |
| 235 | + ) |
| 236 | + |
| 237 | + @test length(parsed) == 1 |
| 238 | + @test length(parsed[1]) == 1 |
| 239 | + parsed_member = parsed[1][1] |
| 240 | + @test parsed_member isa CustomPopMember{Float32,Float32} |
| 241 | + @test isapprox(parsed_member.loss, 0.0f0; atol=1.0f-6) |
| 242 | + |
| 243 | + # Confirm equation_search accepts guesses with CustomPopMember |
| 244 | + hof_from_guess = equation_search( |
| 245 | + guess_X, |
| 246 | + guess_y; |
| 247 | + options=guess_options, |
| 248 | + guesses=["x1 - x2"], |
| 249 | + niterations=0, |
| 250 | + parallelism=:serial, |
| 251 | + ) |
| 252 | + |
| 253 | + @test sum(hof_from_guess.exists) > 0 |
| 254 | + guess_best_idx = findlast(hof_from_guess.exists) |
| 255 | + @test !isnothing(guess_best_idx) |
| 256 | + @test hof_from_guess.members[guess_best_idx] isa CustomPopMember |
194 | 257 | end |
0 commit comments