Skip to content

Commit 4cefb1a

Browse files
committed
wip: introduce special assignment operator
1 parent b3ed0b6 commit 4cefb1a

File tree

3 files changed

+86
-5
lines changed

3 files changed

+86
-5
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using DispatchDoctor: @stable, @unstable
1212
include("NodePreallocation.jl")
1313
include("Strings.jl")
1414
include("Evaluate.jl")
15+
include("SpecialOperators.jl")
1516
include("EvaluateDerivative.jl")
1617
include("ChainRules.jl")
1718
include("EvaluationHelpers.jl")
@@ -76,6 +77,7 @@ import .StringsModule: get_op_name
7677
@reexport import .EvaluateModule:
7778
eval_tree_array, differentiable_eval_tree_array, EvalOptions
7879
import .EvaluateModule: ArrayBuffer
80+
@reexport import .SpecialOperatorsModule: AssignOperator
7981
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
8082
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
8183
@reexport import .SimplifyModule: combine_operators, simplify_tree!

src/Evaluate.jl

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ function eval_tree_array(
218218
"Bumper and LoopVectorization features are only compatible with numeric element types",
219219
)
220220
end
221+
if any_special_operators(typeof(operators))
222+
cX = copy(cX)
223+
# TODO: This is dangerous if the element type is mutable
224+
end
221225
if _eval_options.bumper isa Val{true}
222226
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
223227
end
@@ -329,22 +333,26 @@ end
329333
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
330334
if long_compilation_time
331335
return quote
336+
op = operators.binops[op_idx]
337+
special_operator(op) && return deg2_eval_special(tree, cX, op, eval_options)
332338
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
333339
!result_l.ok && return result_l
334340
@return_on_nonfinite_array(eval_options, result_l.x)
335341
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
336342
!result_r.ok && return result_r
337343
@return_on_nonfinite_array(eval_options, result_r.x)
338344
# op(x, y), for any x or y
339-
deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options)
345+
deg2_eval(result_l.x, result_r.x, op, eval_options)
340346
end
341347
end
342348
return quote
343349
return Base.Cartesian.@nif(
344350
$nbin,
345351
i -> i == op_idx,
346352
i -> let op = operators.binops[i]
347-
if tree.l.degree == 0 && tree.r.degree == 0
353+
if special_operator(op)
354+
deg2_eval_special(tree, cX, op, eval_options)
355+
elseif tree.l.degree == 0 && tree.r.degree == 0
348356
deg2_l0_r0_eval(tree, cX, op, eval_options)
349357
elseif tree.r.degree == 0
350358
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
@@ -380,13 +388,16 @@ end
380388
eval_options::EvalOptions,
381389
) where {T}
382390
nuna = get_nuna(operators)
391+
special_operators = any_special_operators(operators)
383392
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
384393
if long_compilation_time
385394
return quote
395+
op = operators.unaops[op_idx]
396+
special_operator(op) && return deg1_eval_special(tree, cX, op, eval_options)
386397
result = _eval_tree_array(tree.l, cX, operators, eval_options)
387398
!result.ok && return result
388399
@return_on_nonfinite_array(eval_options, result.x)
389-
deg1_eval(result.x, operators.unaops[op_idx], eval_options)
400+
deg1_eval(result.x, op, eval_options)
390401
end
391402
end
392403
# This @nif lets us generate an if statement over choice of operator,
@@ -396,13 +407,18 @@ end
396407
$nuna,
397408
i -> i == op_idx,
398409
i -> let op = operators.unaops[i]
399-
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
410+
if special_operator(op)
411+
deg1_eval_special(tree, cX, op, eval_options)
412+
elseif !special_operators &&
413+
tree.l.degree == 2 &&
414+
tree.l.l.degree == 0 &&
415+
tree.l.r.degree == 0
400416
# op(op2(x, y)), where x, y, z are constants or variables.
401417
l_op_idx = tree.l.op
402418
dispatch_deg1_l2_ll0_lr0_eval(
403419
tree, cX, op, l_op_idx, operators.binops, eval_options
404420
)
405-
elseif tree.l.degree == 1 && tree.l.l.degree == 0
421+
elseif !special_operators && tree.l.degree == 1 && tree.l.l.degree == 0
406422
# op(op2(x)), where x is a constant or variable.
407423
l_op_idx = tree.l.op
408424
dispatch_deg1_l1_ll0_eval(
@@ -925,4 +941,10 @@ end
925941
end
926942
end
927943

944+
# Overloaded by SpecialOperators.jl:
945+
function any_special_operators end
946+
function special_operator end
947+
function deg2_eval_special end
948+
function deg1_eval_special end
949+
928950
end

src/SpecialOperators.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
module SpecialOperatorsModule
2+
3+
using ..OperatorEnumModule: OperatorEnum
4+
using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval
5+
6+
import ..EvaluateModule:
7+
special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
8+
9+
function any_special_operators(::Type{OperatorEnum{B,U}}) where {B,U}
10+
return any(special_operator, B.types) || any(special_operator, U.types)
11+
end
12+
13+
# Use this to customize evaluation behavior for operators:
14+
@inline special_operator(::Type) = false
15+
@inline special_operator(f) = special_operator(typeof(f))
16+
17+
# Base.@kwdef struct WhileOperator <: Function
18+
# max_iters::Int = 100
19+
# end
20+
Base.@kwdef struct AssignOperator <: Function
21+
target_register::Int
22+
end
23+
24+
# @inline special_operator(::Type{WhileOperator}) = true
25+
@inline special_operator(::Type{AssignOperator}) = true
26+
27+
# function deg2_eval_special(tree, cX, op::WhileOperator, eval_options)
28+
# cond = tree.l
29+
# body = tree.r
30+
# for _ in 1:(op.max_iters)
31+
# let cond_result = _eval_tree_array(cond, cX, operators, eval_options)
32+
# !cond_result.ok && return cond_result
33+
# @return_on_nonfinite_array(eval_options, cond_result.x)
34+
# end
35+
# let body_result = _eval_tree_array(body, cX, operators, eval_options)
36+
# !body_result.ok && return body_result
37+
# @return_on_nonfinite_array(eval_options, body_result.x)
38+
# # TODO: Need to somehow mask instances
39+
# end
40+
# end
41+
42+
# return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
43+
# end
44+
# TODO: Need to void any instance of buffer when using while loop.
45+
46+
function deg1_eval_special(tree, cX, op::AssignOperator, eval_options)
47+
result = _eval_tree_array(tree.l, cX, operators, eval_options)
48+
!result.ok && return result
49+
@return_on_nonfinite_array(eval_options, result.x)
50+
target_register = op.target_register
51+
@inbounds @simd for i in eachindex(axes(cX, 2))
52+
cX[target_register, i] = result.x[i]
53+
end
54+
return result
55+
end
56+
57+
end

0 commit comments

Comments
 (0)