|
# supposition_utils.jl
#
# Helpers that build Supposition generators:
#   - `make_expression_generator` returns fully-random DynamicExpressions.Expression
#     objects whose node type is Node{T,D}. D is inferred from `operators`.
#   - `make_input_matrix_generator` returns random input matrices with one row per
#     feature and one column per batch element.
| 6 | + |
| 7 | +module SuppositionUtils |
| 8 | + |
| 9 | +using Supposition: Data |
| 10 | +using DynamicExpressions: Node, Expression, OperatorEnum |
| 11 | +using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals! |
| 12 | +empty_all_globals!() |
| 13 | + |
"""
    make_expression_generator(::Type{T}; num_features=5, operators=..., max_layers=3)

Build a Supposition generator that yields random `Expression`s whose trees are made of
`Node{T,D}` nodes, where `D = length(operators.ops)`.

Leaves are either constants drawn from `Data.Floats{T}` (NaN and Inf excluded) or
feature nodes whose index is sampled from `1:num_features`. Internal nodes are built by
sampling an operator index for some degree `d` in `1:D` and attaching `d` recursively
generated children. Degrees for which `operators` holds no operators are skipped; if no
degree has any operator, only single-leaf expressions are generated.

# Keywords
- `num_features::Int=5`: number of input features; must be at least 1. Each generated
  expression is given the variable names `"x1"`, `"x2"`, ... up to `num_features`.
- `operators::OperatorEnum`: operator set stored in every generated expression.
- `max_layers::Int=3`: forwarded to `Data.Recursive` as its `max_layers` keyword.

# Throws
- `ArgumentError` if `num_features < 1`.
"""
function make_expression_generator(
    ::Type{T};
    num_features::Int=5,
    operators::OperatorEnum=OperatorEnum(((abs, cos), (+, -, *, /))),
    max_layers::Int=3,
) where {T}
    # Feature indices are sampled from `1:num_features`, which must not be empty.
    num_features >= 1 ||
        throw(ArgumentError("num_features must be at least 1, got $num_features"))

    D = length(operators.ops)

    # Leaves: finite constants of type T, or a reference to one of the input features.
    val_gen = Data.Floats{T}(; nans=false, infs=false)
    val_node_gen = map(v -> Node{T,D}(; val=v), val_gen)

    feature_gen = Data.SampledFrom(1:num_features)
    feature_node_gen = map(i -> Node{T,D}(; feature=i), feature_gen)

    leaf_gen = val_node_gen | feature_node_gen

    # A degree with no operators cannot produce a node, and sampling an operator
    # index from the empty range `1:0` would fail, so only degrees that actually
    # have operators get a wrapper.
    usable_degrees = filter(degree -> !isempty(operators[degree]), 1:D)

    wrapper_funcs = map(usable_degrees) do degree
        op_gen = Data.SampledFrom(1:length(operators[degree]))

        # Given a generator for subtrees, generate a node of this degree: one
        # sampled operator index plus `degree` generated children.
        child -> map(
            (op_idx, args...) -> Node{T,D}(; op=op_idx, children=args),
            op_gen,
            ntuple(_ -> child, degree)...,
        )
    end

    tree_gen = if isempty(wrapper_funcs)
        # No operators at all: the only possible trees are single leaves.
        leaf_gen
    else
        # One recursion layer: pick any usable degree and wrap the child generator.
        expr_wrap = child -> foldl(|, (w(child) for w in wrapper_funcs))
        Data.Recursive(leaf_gen, expr_wrap; max_layers)
    end

    return map(
        t -> Expression(t; operators, variable_names=["x$i" for i in 1:num_features]),
        tree_gen,
    )
end
| 49 | + |
"""
    make_input_matrix_generator(::Type{T}=Float64; n_features=5, min_batch=1, max_batch=16)

Build a Supposition generator of matrices with `n_features` rows whose entries are
drawn from `Data.Floats{T}` (NaN and Inf excluded). The number of columns is drawn from
`Data.Integers(min_batch, max_batch)`.
"""
function make_input_matrix_generator(
    ::Type{T}=Float64; n_features::Int=5, min_batch::Int=1, max_batch::Int=16
) where {T}
    entry_gen = Data.Floats{T}(; nans=false, infs=false)
    n_cols_gen = Data.Integers(min_batch, max_batch)

    # For a fixed column count, generate a flat vector of exactly the right length
    # and reshape it into an `n_features`-row matrix.
    function matrices_with(n_cols)
        total = n_features * n_cols
        flat_gen = Data.Vectors(entry_gen; min_size=total, max_size=total)
        return Data.map(flat -> reshape(flat, n_features, n_cols), flat_gen)
    end

    # The vector length depends on the drawn column count, so the matrix generator
    # is chosen per drawn value via `bind`.
    return Data.bind(matrices_with, n_cols_gen)
end
| 63 | + |
| 64 | +end |
| 65 | + |
| 66 | +using .SuppositionUtils: make_input_matrix_generator, make_expression_generator |
0 commit comments