@@ -558,6 +558,14 @@ that the user would never write themselves.
558558"""
559559const __DERIVATIVE__ = " __DERIVATIVE__"
560560
561+ # This function helps simplify df_du * du_dx in the commonn case that `du_dx`
562+ # is `true` (when u = x), or `false` (when x ∉ u).
563+ function _univariate_chain_rule (df_du, du_dx)
564+ return MOI. ScalarNonlinearFunction (:* , Any[df_du, du_dx])
565+ end
566+
567+ _univariate_chain_rule (df_du, du_dx:: Bool ) = ifelse (du_dx, df_du, du_dx)
568+
561569function derivative (f:: MOI.ScalarNonlinearFunction , x:: MOI.VariableIndex )
562570 if length (f. args) == 1
563571 u = only (f. args)
@@ -571,28 +579,28 @@ function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
571579 :ifelse ,
572580 Any[MOI. ScalarNonlinearFunction (:>= , Any[u, 0 ]), 1 , - 1 ],
573581 )
574- return MOI . ScalarNonlinearFunction (: * , Any[ df_du, du_dx] )
582+ return _univariate_chain_rule ( df_du, du_dx)
575583 elseif f. head == :sign
576584 return false
577585 elseif f. head == :deg2rad
578586 df_du = deg2rad (1 )
579- return MOI . ScalarNonlinearFunction (: * , Any[ df_du, du_dx] )
587+ return _univariate_chain_rule ( df_du, du_dx)
580588 elseif f. head == :rad2deg
581589 df_du = rad2deg (1 )
582- return MOI . ScalarNonlinearFunction (: * , Any[ df_du, du_dx] )
590+ return _univariate_chain_rule ( df_du, du_dx)
583591 end
584592 for (key, df, _) in MOI. Nonlinear. SYMBOLIC_UNIVARIATE_EXPRESSIONS
585593 if key == f. head
586594 # The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
587595 df_du = _replace_expression (copy (df), u)
588- return MOI . ScalarNonlinearFunction (: * , Any[ df_du, du_dx] )
596+ return _univariate_chain_rule ( df_du, du_dx)
589597 end
590598 end
591599 # Delay derivative until evaluation. This may result in a later
592600 # UnsupportedNonlinearOperator error, but we can't tell just yet.
593601 d_op = Symbol (__DERIVATIVE__ * " $(f. head) " )
594602 df_du = MOI. ScalarNonlinearFunction (d_op, Any[u])
595- return MOI . ScalarNonlinearFunction (: * , Any[ df_du, du_dx] )
603+ return _univariate_chain_rule ( df_du, du_dx)
596604 end
597605 if f. head == :+
598606 # d/dx(+(args...)) = +(d/dx args)
0 commit comments