@@ -18,7 +18,9 @@ function DI.prepare_pushforward_nokwarg(
1818 step_der_var = derivative (f (x_var + t_var * dx_var, context_vars... ), t_var)
1919 pf_var = substitute (step_der_var, Dict (t_var => zero (eltype (x))))
2020
21- res = build_function (pf_var, x_var, dx_var, context_vars... ; expression= Val (false ))
21+ res = build_function (
22+ pf_var, x_var, dx_var, context_vars... ; expression= Val (false ), cse= true
23+ )
2224 (pf_exe, pf_exe!) = if res isa Tuple
2325 res
2426 elseif res isa RuntimeGeneratedFunction
@@ -102,7 +104,7 @@ function DI.prepare_derivative_nokwarg(
102104 context_vars = variablize (contexts)
103105 der_var = derivative (f (x_var, context_vars... ), x_var)
104106
105- res = build_function (der_var, x_var, context_vars... ; expression= Val (false ))
107+ res = build_function (der_var, x_var, context_vars... ; expression= Val (false ), cse = true )
106108 (der_exe, der_exe!) = if res isa Tuple
107109 res
108110 elseif res isa RuntimeGeneratedFunction
@@ -177,7 +179,9 @@ function DI.prepare_gradient_nokwarg(
177179 # Symbolic.gradient only accepts vectors
178180 grad_var = gradient (f (x_var, context_vars... ), vec (x_var))
179181
180- res = build_function (grad_var, vec (x_var), context_vars... ; expression= Val (false ))
182+ res = build_function (
183+ grad_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
184+ )
181185 (grad_exe, grad_exe!) = res
182186 return SymbolicsOneArgGradientPrep (_sig, grad_exe, grad_exe!)
183187end
@@ -254,7 +258,7 @@ function DI.prepare_jacobian_nokwarg(
254258 jacobian (f (x_var, context_vars... ), x_var)
255259 end
256260
257- res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ))
261+ res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ), cse = true )
258262 (jac_exe, jac_exe!) = res
259263 return SymbolicsOneArgJacobianPrep (_sig, jac_exe, jac_exe!)
260264end
@@ -333,7 +337,9 @@ function DI.prepare_hessian_nokwarg(
333337 hessian (f (x_var, context_vars... ), vec (x_var))
334338 end
335339
336- res = build_function (hess_var, vec (x_var), context_vars... ; expression= Val (false ))
340+ res = build_function (
341+ hess_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
342+ )
337343 (hess_exe, hess_exe!) = res
338344
339345 gradient_prep = DI. prepare_gradient_nokwarg (
@@ -420,7 +426,12 @@ function DI.prepare_hvp_nokwarg(
420426 hvp_vec_var = hess_var * vec (dx_var)
421427
422428 res = build_function (
423- hvp_vec_var, vec (x_var), vec (dx_var), context_vars... ; expression= Val (false )
429+ hvp_vec_var,
430+ vec (x_var),
431+ vec (dx_var),
432+ context_vars... ;
433+ expression= Val (false ),
434+ cse= true ,
424435 )
425436 (hvp_exe, hvp_exe!) = res
426437
@@ -508,7 +519,7 @@ function DI.prepare_second_derivative_nokwarg(
508519 der_var = derivative (f (x_var, context_vars... ), x_var)
509520 der2_var = derivative (der_var, x_var)
510521
511- res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ))
522+ res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ), cse = true )
512523 (der2_exe, der2_exe!) = if res isa Tuple
513524 res
514525 elseif res isa RuntimeGeneratedFunction
0 commit comments