Skip to content

Commit 6b90350

Browse files
committed
Merge branch 'master' into n-arity
2 parents 1b542bc + db556dc commit 6b90350

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1240
-423
lines changed

.github/workflows/Documentation.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,5 @@ jobs:
2727
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key
2828
run: |
2929
cd docs
30-
cat ../README.md | sed '/Search options/,$d' > tmp1.md
31-
cat tmp1.md src/index.md > tmp2.md
32-
mv tmp2.md src/index.md
3330
julia --project=. make.jl
3431

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.18.2"
4+
version = "0.18.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,7 +15,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1616
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1717
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
18-
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
1918

2019
[weakdeps]
2120
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
@@ -44,7 +43,6 @@ PackageExtensionCompat = "1"
4443
PrecompileTools = "1"
4544
Reexport = "1"
4645
SymbolicUtils = "0.19, ^1.0.5, 2"
47-
TestItems = "0.1"
4846
Zygote = "0.6"
4947
julia = "1.6"
5048

benchmark/benchmarks.jl

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
using DynamicExpressions, BenchmarkTools, Random
22

33
# Trigger extensions:
4-
using LoopVectorization
5-
using Bumper
6-
using StrideArrays
7-
using Zygote
4+
using LoopVectorization, Bumper, StrideArrays, Zygote
85

96
if PACKAGE_VERSION < v"0.14.0"
107
@eval using DynamicExpressions: Node as GraphNode
@@ -18,6 +15,14 @@ else
1815
@eval using DynamicExpressions.NodeUtilsModule: is_constant
1916
end
2017

18+
if PACKAGE_VERSION < v"0.18.6"
19+
@eval using DynamicExpressions:
20+
index_constants as index_constant_nodes,
21+
count_constants as count_constant_nodes,
22+
get_constants as get_scalar_constants,
23+
set_constants! as set_scalar_constants!
24+
end
25+
2126
include("../test/tree_gen_utils.jl")
2227

2328
const SUITE = BenchmarkGroup()
@@ -113,15 +118,16 @@ end
113118
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
114119
return :(copy_node(t)) # Assume type used to infer sharing
115120
end
116-
@generated function get_set_constants!(tree::N) where {T,N<:AbstractExpressionNode{T}}
117-
if !(@isdefined set_constants!)
118-
return :(set_constants(tree, get_constants(tree)))
119-
elseif hasmethod(set_constants!, Tuple{N, Vector{T}})
120-
return :(set_constants!(tree, get_constants(tree)))
121+
@generated function get_set_constants!(tree::N) where {N}
122+
T = eltype(N)
123+
if !(@isdefined set_scalar_constants!)
124+
return :(set_scalar_constants(tree, get_scalar_constants(tree)))
125+
elseif hasmethod(set_scalar_constants!, Tuple{N, Vector{T}})
126+
return :(set_scalar_constants!(tree, get_scalar_constants(tree)))
121127
else
122128
return quote
123-
let (x, refs) = get_constants(tree)
124-
set_constants!(tree, x, refs)
129+
let (x, refs) = get_scalar_constants(tree)
130+
set_scalar_constants!(tree, x, refs)
125131
end
126132
end
127133
end
@@ -141,12 +147,12 @@ function benchmark_utilities()
141147
:combine_operators,
142148
:count_nodes,
143149
:count_depth,
144-
:count_constants,
150+
:count_constant_nodes,
145151
:has_constants,
146152
:has_operators,
147153
:is_constant,
148154
:get_set_constants!,
149-
:index_constants,
155+
:index_constant_nodes,
150156
:string_tree,
151157
:hash,
152158
)
@@ -157,9 +163,9 @@ function benchmark_utilities()
157163
[
158164
:simplify_tree,
159165
:count_nodes,
160-
:count_constants,
166+
:count_constant_nodes,
161167
:get_set_constants!,
162-
:index_constants,
168+
:index_constant_nodes,
163169
:string_tree,
164170
],
165171
)
@@ -207,7 +213,8 @@ function benchmark_utilities()
207213
setup=(
208214
ntrees=100;
209215
n=20;
210-
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
216+
rng=Random.MersenneTwister(0);
217+
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32, Node, rng)) for _ in 1:ntrees]
211218
)
212219
)
213220
#! format: on
@@ -216,6 +223,37 @@ function benchmark_utilities()
216223
end
217224
end
218225

226+
# Additional methods
227+
@static if PACKAGE_VERSION >= v"0.18.0"
228+
suite["get_set_constants_parametric"] = @benchmarkable(
229+
[get_set_constants!(ex) for ex in exs],
230+
seconds = 10.0,
231+
setup = (
232+
operators = $operators;
233+
ntrees = 100;
234+
n = 20;
235+
n_features = 5;
236+
n_params = 3;
237+
n_param_classes = 10;
238+
rng = Random.MersenneTwister(0);
239+
exs = [
240+
let tree = gen_random_tree_fixed_size(
241+
n, operators, n_features, Float32, ParametricNode, rng
242+
)
243+
ex = ParametricExpression(
244+
tree;
245+
operators,
246+
variable_names=map(i -> "x$i", 1:n_features),
247+
parameters=randn(rng, Float32, n_params, n_param_classes),
248+
parameter_names=map(i -> "p$i", 1:n_params),
249+
)
250+
ex
251+
end for _ in 1:ntrees
252+
]
253+
)
254+
)
255+
end
256+
219257
return suite
220258
end
221259

docs/make.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,45 @@ using Documenter
22
using DynamicExpressions
33
using Random: AbstractRNG
44

5+
readme = joinpath(@__DIR__, "..", "README.md")
6+
7+
index_content = let r = read(readme, String)
8+
# Wrap img tags in raw HTML blocks:
9+
r = replace(r, r"(<img\s+[^>]+>)" => s"""
10+
11+
```@raw html
12+
\1
13+
```
14+
15+
""")
16+
# Remove end img tags:
17+
r = replace(r, r"</img>" => "")
18+
# Remove div tags:
19+
r = replace(r, r"<div[^>]*>" => "")
20+
# Remove end div tags:
21+
r = replace(r, r"</div>" => "")
22+
23+
top_part = """
24+
# Introduction
25+
26+
"""
27+
28+
bottom_part = """
29+
## Contents
30+
31+
```@contents
32+
Pages = ["utils.md", "api.md", "eval.md"]
33+
```
34+
"""
35+
36+
join((top_part, r, bottom_part), "\n")
37+
end
38+
39+
index_md = joinpath(@__DIR__, "src", "index.md")
40+
open(index_md, "w") do f
41+
write(f, index_content)
42+
end
43+
544
makedocs(;
645
sitename="DynamicExpressions.jl",
746
authors="Miles Cranmer",
@@ -11,4 +50,33 @@ makedocs(;
1150
warnonly=true,
1251
)
1352

53+
# Forward links from old docs:
54+
redirect_page = """
55+
<!DOCTYPE html>
56+
<html lang="en">
57+
<head>
58+
<meta charset="UTF-8">
59+
<title>Redirecting...</title>
60+
<script type="text/javascript">
61+
var fragment = window.location.hash;
62+
window.location.href = "../api/" + fragment;
63+
</script>
64+
</head>
65+
<body>
66+
<p>If you are not redirected automatically, follow this <a id="redirect-link" href="../api/">link to API</a>.</p>
67+
<script type="text/javascript">
68+
document.getElementById('redirect-link').href = "../api/" + window.location.hash;
69+
</script>
70+
</body>
71+
</html>
72+
"""
73+
74+
# Create the types directory and write the redirect page
75+
types_dir = joinpath(@__DIR__, "build", "types")
76+
mkpath(types_dir)
77+
redirect_file = joinpath(types_dir, "index.html")
78+
open(redirect_file, "w") do f
79+
write(f, redirect_page)
80+
end
81+
1482
deploydocs(; repo="github.com/SymbolicML/DynamicExpressions.jl.git")

docs/src/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Gets generated:
2+
index.md

docs/src/types.md renamed to docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Types
1+
# API
22

33
## Operator Enum
44

docs/src/index.md

Lines changed: 0 additions & 6 deletions
This file was deleted.

ext/DynamicExpressionsBumperExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module DynamicExpressionsBumperExt
22

33
using Bumper: @no_escape, @alloc
4-
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
5-
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array
4+
using DynamicExpressions:
5+
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array
6+
using DynamicExpressions.UtilsModule: ResultOk, counttuple
67

78
import DynamicExpressions.ExtensionInterfaceModule:
89
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
@@ -52,7 +53,7 @@ function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where
5253
cumulator.ok || return cumulator
5354

5455
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
55-
return ResultOk(out, !is_bad_array(out))
56+
return ResultOk(out, is_valid_array(out))
5657
end
5758
function dispatch_kerns!(
5859
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
@@ -63,7 +64,7 @@ function dispatch_kerns!(
6364
out = dispatch_kern2!(
6465
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
6566
)
66-
return ResultOk(out, !is_bad_array(out))
67+
return ResultOk(out, is_valid_array(out))
6768
end
6869

6970
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}

ext/DynamicExpressionsOptimExt.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using DynamicExpressions:
55
AbstractExpressionNode,
66
filter_map,
77
eval_tree_array,
8-
get_constants,
9-
set_constants!
8+
get_scalar_constants,
9+
set_scalar_constants!,
10+
get_number_type
1011
using Compat: @inline
1112

1213
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -44,9 +45,14 @@ function wrap_func(
4445
function wrapped_f(args::Vararg{Any,M}) where {M}
4546
first_args = args[begin:(end - 1)]
4647
x = args[end]
47-
set_constants!(tree, x, refs)
48+
set_scalar_constants!(tree, x, refs)
4849
return @inline(f(first_args..., tree))
4950
end
51+
# without first args, it looks like this
52+
# function wrapped_f(x)
53+
# set_scalar_constants!(tree, x, refs)
54+
# return @inline(f(tree))
55+
# end
5056
return wrapped_f
5157
end
5258
function wrap_func(
@@ -100,7 +106,8 @@ function Optim.optimize(
100106
if make_copy
101107
tree = copy(tree)
102108
end
103-
x0, refs = get_constants(tree)
109+
110+
x0, refs = get_scalar_constants(tree)
104111
if !isnothing(h!)
105112
throw(
106113
ArgumentError(
@@ -117,7 +124,7 @@ function Optim.optimize(
117124
)
118125
end
119126
minimizer = Optim.minimizer(base_res)
120-
set_constants!(tree, minimizer, refs)
127+
set_scalar_constants!(tree, minimizer, refs)
121128
return ExpressionOptimizationResults(base_res, tree)
122129
end
123130

0 commit comments

Comments
 (0)