11module EvaluateDerivativeModule
22
3- import .. NodeModule: AbstractExpressionNode, constructorof
3+ import .. NodeModule: AbstractExpressionNode, constructorof, get_children
44import .. OperatorEnumModule: OperatorEnum
55import .. UtilsModule: fill_similar, ResultOk2
66import .. ValueInterfaceModule: is_valid_array
@@ -66,54 +66,18 @@ function eval_diff_tree_array(
6666end
6767
# Recursively evaluate `tree` and its derivative with respect to input
# feature `direction`, dispatching on the node's degree.
#
# A degree-0 node (leaf) is handled directly by `diff_deg0_eval`; for
# degree 1..D, `Base.Cartesian.@nif` converts the runtime `deg` into the
# compile-time `Val(i)` required by `dispatch_diff_degn_eval`, so each
# arity gets its own specialized branch.
#
# NOTE(review): the previous revision re-wrapped the result here with an
# `is_valid_array` check on `result.x`/`result.dx`; that check is absent
# in this version — presumably performed elsewhere. Confirm with callers.
@generated function _eval_diff_tree_array(
    tree::AbstractExpressionNode{T,D},
    cX::AbstractMatrix{T},
    operators::OperatorEnum,
    direction::Integer,
)::ResultOk2 where {T<:Number,D}
    quote
        deg = tree.degree
        deg == 0 && return diff_deg0_eval(tree, cX, direction)
        Base.Cartesian.@nif(
            $D,
            i -> i == deg,
            i -> dispatch_diff_degn_eval(tree, cX, Val(i), operators, direction)
        )
    end
end
@@ -130,58 +94,71 @@ function diff_deg0_eval(
13094 return ResultOk2 (const_part, derivative_part, true )
13195end
13296
# Apply the arity-`N` operator `op` and its Zygote-derived gradient across
# all sample indices, fusing value and derivative updates in one pass.
#
# `x_cumulators[i]` / `dx_cumulators[i]` hold child i's evaluated values
# and derivatives. Results are written in place into the first child's
# buffers (`x_cumulator_1`, `dx_cumulator_1`), which are returned inside a
# `ResultOk2` flagged `true` (validity is checked by the caller).
@generated function diff_degn_eval(
    x_cumulators::NTuple{N}, dx_cumulators::NTuple{N}, op::F, direction::Integer
) where {N,F}
    quote
        # Bind tuple elements to locals x_cumulator_1..N / dx_cumulator_1..N
        Base.Cartesian.@nexprs($N, i -> begin
            x_cumulator_i = x_cumulators[i]
            dx_cumulator_i = dx_cumulators[i]
        end)
        diff_op = _zygote_gradient(op, Val(N))
        @inbounds @simd for j in eachindex(x_cumulator_1)
            # Forward value: op(x_1[j], ..., x_N[j])
            x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
            # Partials of op with respect to each of its N arguments
            Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
                $N, diff_op, i -> x_cumulator_i[j]
            )
            # Chain rule: dx = sum_i (∂op/∂arg_i) * d(arg_i)/d(direction)
            dx = Base.Cartesian.@ncall($N, +, i -> grad_i * dx_cumulator_i[j])
            x_cumulator_1[j] = x
            dx_cumulator_1[j] = dx
        end
        return ResultOk2(x_cumulator_1, dx_cumulator_1, true)
    end
end
156118
# Evaluate all children of an arity-`degree` node, then dispatch to
# `diff_degn_eval` with the node's operator.
#
# The operator lookup strategy depends on how many operators exist at this
# degree: above `OPERATOR_LIMIT_BEFORE_SLOWDOWN`, a dynamic indexed call is
# generated (avoids compiling one branch per operator); otherwise
# `Base.Cartesian.@nif` unrolls one statically-dispatched branch per
# operator, keeping the call type-stable.
@generated function dispatch_diff_degn_eval(
    tree::AbstractExpressionNode{T,D},
    cX::AbstractMatrix{T},
    ::Val{degree},
    operators::OperatorEnum{OPS},
    direction::Integer,
) where {T<:Number,D,degree,OPS}
    nops = length(OPS.types[degree].types)

    # Shared preamble: evaluate each child (early-returning the first
    # failed result) and gather the value/derivative buffers.
    setup = quote
        cs = get_children(tree, Val($degree))
        Base.Cartesian.@nexprs(
            $degree,
            i -> begin
                result_i = _eval_diff_tree_array(cs[i], cX, operators, direction)
                !result_i.ok && return result_i
            end
        )
        x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x)
        dx_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx)
        op_idx = tree.op
    end

    # TODO: Need to add the case for many operators
    return if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN
        quote
            $setup
            diff_degn_eval(
                x_cumulators, dx_cumulators, operators[$degree][op_idx], direction
            )
        end
    else
        quote
            $setup
            Base.Cartesian.@nif(
                $nops,
                i -> i == op_idx,
                i -> diff_degn_eval(
                    x_cumulators, dx_cumulators, operators[$degree][i], direction
                )
            )
        end
    end
end
186163
187164"""
0 commit comments