@@ -30,7 +30,7 @@ namespace ens {
3030 * documentation on function types included with this distribution or on the
3131 * ensmallen website.
3232 */
33- template <typename VecType = arma::vec>
33+ template <typename VecType = arma::vec> // TODO: remove for ensmallen 4.x
3434class AugLagrangianType
3535{
3636 public:
@@ -50,7 +50,7 @@ class AugLagrangianType
5050 const L_BFGS& lbfgs = L_BFGS());
5151
5252 /* *
53- * Optimize the function. The value '1 ' is used for the initial value of each
53+ * Optimize the function. The value '0 ' is used for the initial value of each
5454 * Lagrange multiplier. To set the Lagrange multipliers yourself, use the
5555 * other overload of Optimize().
5656 *
@@ -67,7 +67,8 @@ class AugLagrangianType
6767 typename MatType,
6868 typename GradType,
6969 typename ... CallbackTypes>
70- typename std::enable_if<IsMatrixType<GradType>::value, bool >::type
70+ typename std::enable_if<IsMatrixType<GradType>::value &&
71+ IsAllNonMatrix<CallbackTypes...>::value, bool >::type
7172 Optimize (LagrangianFunctionType& function,
7273 MatType& coordinates,
7374 CallbackTypes&&... callbacks);
@@ -76,9 +77,10 @@ class AugLagrangianType
7677 template <typename LagrangianFunctionType,
7778 typename MatType,
7879 typename ... CallbackTypes>
79- bool Optimize (LagrangianFunctionType& function,
80- MatType& coordinates,
81- CallbackTypes&&... callbacks)
80+ typename std::enable_if<IsAllNonMatrix<CallbackTypes...>::value, bool >::type
81+ Optimize (LagrangianFunctionType& function,
82+ MatType& coordinates,
83+ CallbackTypes&&... callbacks)
8284 {
8385 return Optimize<LagrangianFunctionType, MatType, MatType,
8486 CallbackTypes...>(function, coordinates,
@@ -97,26 +99,50 @@ class AugLagrangianType
9799 * @tparam CallbackTypes Types of callback functions.
98100 * @param function The function to optimize.
99101 * @param coordinates Output matrix to store the optimized coordinates in.
100- * @param initLambda Vector of initial Lagrange multipliers. Should have
101- * length equal to the number of constraints.
102- * @param initSigma Initial penalty parameter.
102+ * @param lambda Vector containing initial Lagrange multipliers. Should have
103+ * length equal to the number of constraints. This will be overwritten
104+ * with the Lagrange multipliers that are found during optimization.
105+ * @param sigma Initial penalty parameter. This will be overwritten with the
106+ * final penalty value used during optimization.
103107 * @param callbacks Callback functions.
104108 */
105109 template <typename LagrangianFunctionType,
106110 typename MatType,
111+ typename InVecType,
107112 typename GradType,
108113 typename ... CallbackTypes>
114+ [[deprecated(" use Optimize() with non-const lambda/sigma instead" )]]
109115 typename std::enable_if<IsMatrixType<GradType>::value, bool >::type
110116 Optimize (LagrangianFunctionType& function,
111117 MatType& coordinates,
112- const VecType & initLambda,
118+ const InVecType & initLambda,
113119 const double initSigma,
120+ CallbackTypes&&... callbacks)
121+ {
122+ deprecatedLambda = initLambda;
123+ deprecatedSigma = initSigma;
124+ const bool result = Optimize (function, coordinates, this ->deprecatedLambda ,
125+ this ->deprecatedSigma ,
126+ std::forward<CallbackTypes>(callbacks)...);
127+ }
128+
129+ template <typename LagrangianFunctionType,
130+ typename MatType,
131+ typename InVecType,
132+ typename GradType,
133+ typename ... CallbackTypes>
134+ typename std::enable_if<IsMatrixType<GradType>::value, bool >::type
135+ Optimize (LagrangianFunctionType& function,
136+ MatType& coordinates,
137+ InVecType& lambda,
138+ double & sigma,
114139 CallbackTypes&&... callbacks);
115140
116141 // ! Forward the MatType as GradType.
117142 template <typename LagrangianFunctionType,
118143 typename MatType,
119144 typename ... CallbackTypes>
145+ [[deprecated(" use Optimize() with non-const lambda/sigma instead" )]]
120146 bool Optimize (LagrangianFunctionType& function,
121147 MatType& coordinates,
122148 const VecType& initLambda,
@@ -128,20 +154,39 @@ class AugLagrangianType
128154 std::forward<CallbackTypes>(callbacks)...);
129155 }
130156
157+ template <typename LagrangianFunctionType,
158+ typename MatType,
159+ typename InVecType,
160+ typename ... CallbackTypes>
161+ bool Optimize (LagrangianFunctionType& function,
162+ MatType& coordinates,
163+ InVecType& lambda,
164+ double & sigma,
165+ CallbackTypes&&... callbacks)
166+ {
167+ return Optimize<LagrangianFunctionType, MatType, InVecType, MatType,
168+ CallbackTypes...>(function, coordinates, lambda, sigma,
169+ std::forward<CallbackTypes>(callbacks)...);
170+ }
171+
131172 // ! Get the L-BFGS object used for the actual optimization.
132173 const L_BFGS& LBFGS () const { return lbfgs; }
133174 // ! Modify the L-BFGS object used for the actual optimization.
134175 L_BFGS& LBFGS () { return lbfgs; }
135176
136177 // ! Get the Lagrange multipliers.
137- const VecType& Lambda () const { return lambda; }
178+ [[deprecated(" use Optimize() with lambda/sigma parameters instead" )]]
179+ const VecType& Lambda () const { return deprecatedLambda; }
138180 // ! Modify the Lagrange multipliers (i.e. set them before optimization).
139- VecType& Lambda () { return lambda; }
181+ [[deprecated(" use Optimize() with lambda/sigma parameters instead" )]]
182+ VecType& Lambda () { return deprecatedLambda; }
140183
141184 // ! Get the penalty parameter.
142- double Sigma () const { return sigma; }
185+ [[deprecated(" use Optimize() with lambda/sigma parameters instead" )]]
186+ double Sigma () const { return deprecatedSigma; }
143187 // ! Modify the penalty parameter.
144- double & Sigma () { return sigma; }
188+ [[deprecated(" use Optimize() with lambda/sigma parameters instead" )]]
189+ double & Sigma () { return deprecatedSigma; }
145190
146191 // ! Get the maximum iterations
147192 size_t MaxIterations () const { return maxIterations; }
@@ -174,35 +219,37 @@ class AugLagrangianType
174219 // ! Controls early termination of the optimization process.
175220 bool terminate;
176221
222+ // NOTE: these will be removed in ensmallen 4.x!
177223 // ! Lagrange multipliers.
178- VecType lambda;
179-
224+ VecType deprecatedLambda;
180225 // ! Penalty parameter.
181- double sigma ;
226+ double deprecatedSigma ;
182227
183228 /* *
184229 * Internal optimization function: given an initialized AugLagrangianFunction,
185230 * perform the optimization itself.
186231 */
187232 template <typename LagrangianFunctionType,
188233 typename MatType,
234+ typename InVecType,
189235 typename GradType,
190236 typename ... CallbackTypes>
191237 typename std::enable_if<IsMatrixType<GradType>::value, bool >::type
192- Optimize (AugLagrangianFunction<LagrangianFunctionType, VecType >& augfunc,
238+ Optimize (AugLagrangianFunction<LagrangianFunctionType, InVecType >& augfunc,
193239 MatType& coordinates,
194240 CallbackTypes&&... callbacks);
195241
196242 // ! Forward the MatType as GradType.
197243 template <typename LagrangianFunctionType,
198244 typename MatType,
245+ typename InVecType,
199246 typename ... CallbackTypes>
200247 bool Optimize (
201- AugLagrangianFunction<LagrangianFunctionType, VecType >& function,
248+ AugLagrangianFunction<LagrangianFunctionType, InVecType >& function,
202249 MatType& coordinates,
203250 CallbackTypes&&... callbacks)
204251 {
205- return Optimize<LagrangianFunctionType, MatType, MatType,
252+ return Optimize<LagrangianFunctionType, MatType, InVecType, MatType,
206253 CallbackTypes...>(function, coordinates,
207254 std::forward<CallbackTypes>(callbacks)...);
208255 }
0 commit comments