
Commit 4a52c37

docs: update quickstart to expression type
1 parent 7e0a1b4 commit 4a52c37

File tree: 1 file changed (+46, -62 lines)


README.md

Lines changed: 46 additions & 62 deletions
@@ -19,8 +19,7 @@ A dynamic expression is a snippet of code that can change throughout runtime - c
 3. It then generates specialized [evaluation kernels](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquation.jl#L29-L66) for the space of potential operators.
 4. It also generates kernels for the [first-order derivatives](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquationDerivative.jl#L139-L175), using [Zygote.jl](https://github.com/FluxML/Zygote.jl).
 5. DynamicExpressions.jl can also operate on arbitrary other types (vectors, tensors, symbols, strings, or even unions) - see last part below.
-
-It also has import and export functionality with [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), so you can move your runtime expression into a CAS!
+6. It also has import and export functionality with [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl).
 
 
 ## Example
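For readers following item 6 above, the SymbolicUtils.jl bridge can be sketched roughly as below. The `node_to_symbolic`/`symbolic_to_node` helpers, and calling them on `expression.tree`, are assumptions on my part and not part of this diff; check the DynamicExpressions.jl docs for the exact names and keywords.

```julia
using DynamicExpressions
using SymbolicUtils

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
variable_names = ["x1", "x2"]
x1, x2 = (Expression(Node{Float64}(feature=i); operators, variable_names) for i in 1:2)
expression = x1 * cos(x2 - 3.2)

# Assumed helpers for the SymbolicUtils.jl round trip (see the package docs):
symbolic = node_to_symbolic(expression.tree, operators)  # runtime tree -> CAS expression
tree2 = symbolic_to_node(symbolic, operators)            # CAS expression -> runtime tree
```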
@@ -29,18 +28,17 @@ It also has import and export functionality with [SymbolicUtils.jl](https://gith
 using DynamicExpressions
 
 operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
+variable_names = ["x1", "x2"]
 
-x1 = Node{Float64}(feature=1)
-x2 = Node{Float64}(feature=2)
+x1 = Expression(Node{Float64}(feature=1); operators, variable_names)
+x2 = Expression(Node{Float64}(feature=2); operators, variable_names)
 
 expression = x1 * cos(x2 - 3.2)
 
 X = randn(Float64, 2, 100);
-expression(X, operators) # 100-element Vector{Float64}
+expression(X) # 100-element Vector{Float64}
 ```
 
-(We can construct this expression with normal operators, since calling `OperatorEnum()` will `@eval` new functions on `Node` that use the specified enum.)
-
 ## Speed
 
 First, what happens if we naively use Julia symbols to define and then evaluate this expression?
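The slow baseline that answers this question sits in unchanged context outside the hunk; purely as an illustration, a naive symbol-based evaluation would look roughly like this (timing left out, since it is machine-dependent):

```julia
using BenchmarkTools

X = randn(Float64, 2, 100)

# Build the expression as a Julia `Expr` and evaluate it with `eval` on each call.
# Every call pays for lowering/compilation and global-scope lookup, which is why
# this approach is slow compared to a specialized expression evaluator.
@btime eval(:(X[1, :] .* cos.(X[2, :] .- 3.2)))
```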
@@ -53,27 +51,25 @@ First, what happens if we naively use Julia symbols to define and then evaluate
 This is quite slow, meaning it will be hard to quickly search over the space of expressions. Let's see how DynamicExpressions.jl compares:
 
 ```julia
-@btime expression(X, operators)
-# 693 ns
+@btime expression(X)
+# 607 ns
 ```
 
-Much faster! And we didn't even need to compile it. (Internally, this is calling `eval_tree_array(expression, X, operators)`).
+Much faster! And we didn't even need to compile it. (Internally, this is calling `eval_tree_array(expression, X)`).
 
 If we change `expression` dynamically with a random number generator, it will have the same performance:
 
 ```julia
-@btime begin
-    expression.op = rand(1:3) # random operator in [+, -, *]
-    expression(X, operators)
-end
-# 842 ns
+@btime ex(X) setup=(ex = copy(expression); ex.tree.op = rand(1:3) #= random operator in [+, -, *] =#)
+# 640 ns
 ```
+
 Now, let's see the performance if we had hard-coded these expressions:
 
 ```julia
 f(X) = X[1, :] .* cos.(X[2, :] .- 3.2)
 @btime f(X)
-# 708 ns
+# 629 ns
 ```
 
 So, our dynamic expression evaluation is about the same (or even a bit faster) as evaluating a basic hard-coded expression! Let's see if we can optimize the speed of the hard-coded version:
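The optimized hard-coded version lives outside this hunk; one plausible way to do it (an illustration only, not necessarily what the README itself uses) is an explicit loop with LoopVectorization.jl:

```julia
using BenchmarkTools
using LoopVectorization

X = randn(Float64, 2, 100)

# Hand-optimized equivalent of f(X) = X[1, :] .* cos.(X[2, :] .- 3.2):
# a single fused, SIMD-vectorized loop with one output allocation.
function f_fast(X)
    y = Vector{Float64}(undef, size(X, 2))
    @turbo for i in axes(X, 2)
        y[i] = X[1, i] * cos(X[2, i] - 3.2)
    end
    return y
end

@btime f_fast(X)
```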
@@ -102,49 +98,37 @@ We can also compute gradients with the same speed:
 ```julia
 using Zygote # trigger extension
 
-operators = OperatorEnum(;
-    binary_operators=[+, -, *],
-    unary_operators=[cos],
-)
-x1 = Node(; feature=1)
-x2 = Node(; feature=2)
+operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
+variable_names = ["x1", "x2"]
+x1, x2 = (Expression(Node{Float64}(feature=i); operators, variable_names) for i in 1:2)
+
 expression = x1 * cos(x2 - 3.2)
 ```
 
 We can take the gradient with respect to inputs simply with the `'` character:
 
 ```julia
-grad = expression'(X, operators)
+grad = expression'(X)
 ```
 
 This is quite fast:
 
 ```julia
-@btime expression'(X, operators)
-# 2894 ns
+@btime expression'(X)
+# 2333 ns
 ```
 
 and again, we can change this expression at runtime, without loss in performance!
 
 ```julia
-@btime begin
-    expression.op = rand(1:3)
-    expression'(X, operators)
-end
-# 3198 ns
+@btime ex'(X) setup=(ex = copy(expression); ex.tree.op = rand(1:3))
+# 2333 ns
 ```
 
 Internally, this is calling the `eval_grad_tree_array` function, which performs forward-mode automatic differentiation on the expression tree with Zygote-compiled kernels. We can also compute the derivative with respect to constants:
 
 ```julia
-result, grad, did_finish = eval_grad_tree_array(expression, X, operators; variable=false)
-```
-
-or with respect to variables, and only in a single direction:
-
-```julia
-feature = 2
-result, grad, did_finish = eval_diff_tree_array(expression, X, operators, feature)
+result, grad, did_finish = eval_grad_tree_array(expression, X; variable=false)
 ```
 
 ## Generic types
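For comparison with the `variable=false` call in the hunk above, the same function differentiates with respect to the input features when `variable=true`, which appears to be what `expression'(X)` computes. A minimal self-contained sketch, reusing the setup shown in this hunk:

```julia
using DynamicExpressions
using Zygote # trigger extension

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
variable_names = ["x1", "x2"]
x1, x2 = (Expression(Node{Float64}(feature=i); operators, variable_names) for i in 1:2)
expression = x1 * cos(x2 - 3.2)
X = randn(Float64, 2, 100)

# variable=true  -> derivatives with respect to the input features,
# variable=false -> derivatives with respect to the constants (as in the hunk above).
result, grad, did_finish = eval_grad_tree_array(expression, X; variable=true)
```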
@@ -154,42 +138,37 @@ result, grad, did_finish = eval_diff_tree_array(expression, X, operators, featur
 I'm so glad you asked. `DynamicExpressions.jl` actually will work for **arbitrary types**! However, to work on operators other than real scalars, you need to use the `GenericOperatorEnum <: AbstractOperatorEnum` instead of the normal `OperatorEnum`. Let's try it with strings!
 
 ```julia
-x1 = Node(String; feature=1)
+_x1 = Node{String}(; feature=1)
 ```
 
 This node will be used to index input data (whatever it may be) with either `data[feature]` (1D abstract arrays) or `selectdim(data, 1, feature)` (ND abstract arrays). Let's now define some operators to use:
 
 ```julia
-my_string_func(x::String) = "ello $x"
+using DynamicExpressions: @declare_expression_operator
 
-operators = GenericOperatorEnum(;
-    binary_operators=[*],
-    unary_operators=[my_string_func]
-)
-```
+my_string_func(x::String) = "ello $x"
+@declare_expression_operator(my_string_func, 1)
 
-Now, let's extend our operators to work with the
-expression types used by `DynamicExpressions.jl`:
+operators = GenericOperatorEnum(; binary_operators=[*], unary_operators=[my_string_func])
 
-```julia
-@extend_operators operators
+x1 = Expression(_x1; operators, variable_names)
 ```
 
 Now, let's create an expression:
 
 ```julia
-tree = "H" * my_string_func(x1)
+expression = "H" * my_string_func(x1)
 # ^ `(H * my_string_func(x1))`
 
-tree(["World!", "Me?"], operators)
+expression(["World!", "Me?"])
 # Hello World!
 ```
 
 So indeed it works for arbitrary types. It is a bit slower due to the potential for type instability, but it's not too bad:
 
 ```julia
-@btime tree(["Hello", "Me?"], operators)
-# 1738 ns
+@btime expression(["Hello", "Me?"])
+# 103.105 ns (4 allocations: 144 bytes)
 ```
 
 ## Tensors
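To round out the `@declare_expression_operator` pattern shown above with a binary case (the arity-2 form also appears in the tensor hunk later in this diff), here is a sketch; the `shout_join` operator and its names are made up purely for illustration:

```julia
using DynamicExpressions
using DynamicExpressions: @declare_expression_operator

# A made-up binary string operator, declared with arity 2 so it can be called
# directly on `Expression` objects.
shout_join(x::String, y::String) = uppercase(x) * " " * uppercase(y)
@declare_expression_operator(shout_join, 2)

operators = GenericOperatorEnum(; binary_operators=[shout_join])
variable_names = ["x1", "x2"]
x1 = Expression(Node{String}(; feature=1); operators, variable_names)
x2 = Expression(Node{String}(; feature=2); operators, variable_names)

expression = shout_join(x1, x2)
expression(["hello", "world"])  # expected: "HELLO WORLD"
```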
@@ -200,37 +179,42 @@ Also yes! Let's see:
 
 ```julia
 using DynamicExpressions
+using DynamicExpressions: @declare_expression_operator
 
 T = Union{Float64,Vector{Float64}}
 
-c1 = Node(T; val=0.0) # Scalar constant
-c2 = Node(T; val=[1.0, 2.0, 3.0]) # Vector constant
-x1 = Node(T; feature=1)
-
 # Some operators on tensors (multiple dispatch can be used for different behavior!)
 vec_add(x, y) = x .+ y
 vec_square(x) = x .* x
 
+# Enable these operators for DynamicExpressions.jl:
+@declare_expression_operator(vec_add, 2)
+@declare_expression_operator(vec_square, 1)
+
 # Set up an operator enum:
 operators = GenericOperatorEnum(;binary_operators=[vec_add], unary_operators=[vec_square])
-@extend_operators operators
 
 # Construct the expression:
-tree = vec_add(vec_add(vec_square(x1), c2), c1)
+variable_names = ["x1"]
+c1 = Expression(Node{T}(; val=0.0); operators, variable_names) # Scalar constant
+c2 = Expression(Node{T}(; val=[1.0, 2.0, 3.0]); operators, variable_names) # Vector constant
+x1 = Expression(Node{T}(; feature=1); operators, variable_names)
+
+expression = vec_add(vec_add(vec_square(x1), c2), c1)
 
 X = [[-1.0, 5.2, 0.1], [0.0, 0.0, 0.0]]
 
 # Evaluate!
-tree(X, operators) # [2.0, 29.04, 3.01]
+expression(X) # [2.0, 29.04, 3.01]
 ```
 
 Note that if an operator is not defined for the particular input, `nothing` will be returned instead.
 
 This is all still pretty fast, too:
 
 ```julia
-@btime tree(X, operators)
-# 2,949 ns
+@btime expression(X)
+# 461.086 ns (13 allocations: 448 bytes)
 @btime eval(:(vec_add(vec_add(vec_square(X[1]), [1.0, 2.0, 3.0]), 0.0)))
 # 115,000 ns
 ```
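The comment about multiple dispatch in the hunk above can be made concrete: because `vec_add` and `vec_square` are ordinary Julia functions, adding specialized methods changes how the same expression evaluates, with no extra DynamicExpressions.jl machinery. A small self-contained sketch (the extra scalar methods are hypothetical additions, not part of this diff):

```julia
using DynamicExpressions
using DynamicExpressions: @declare_expression_operator

T = Union{Float64,Vector{Float64}}

# Generic fallbacks, as in the hunk above...
vec_add(x, y) = x .+ y
vec_square(x) = x .* x
# ...plus specialized scalar methods: ordinary multiple dispatch decides which runs.
vec_square(x::Float64) = x^2
vec_add(x::Float64, y::Float64) = x + y

@declare_expression_operator(vec_add, 2)
@declare_expression_operator(vec_square, 1)

operators = GenericOperatorEnum(; binary_operators=[vec_add], unary_operators=[vec_square])
variable_names = ["x1"]
x1 = Expression(Node{T}(; feature=1); operators, variable_names)
c = Expression(Node{T}(; val=1.0); operators, variable_names)

expression = vec_add(vec_square(x1), c)

expression([2.0])         # scalar path:  2.0^2 + 1.0 = 5.0
expression([[1.0, 2.0]])  # vector path:  [1.0, 4.0] .+ 1.0 = [2.0, 5.0]
```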
