@@ -294,54 +294,44 @@ void IpmSolver::runParallel(std::function<void(int)> taskFunction) {
294294}
295295
296296void IpmSolver::initializeCostateTrajectory (const std::vector<AnnotatedTime>& timeDiscretization, const vector_array_t & stateTrajectory,
297- vector_array_t & costateTrajectory) {
297+ vector_array_t & costateTrajectory) const {
298298 const size_t N = static_cast <int >(timeDiscretization.size ()) - 1 ; // size of the input trajectory
299299 costateTrajectory.clear ();
300300 costateTrajectory.reserve (N + 1 );
301- const auto & ocpDefinition = ocpDefinitions_[0 ];
302-
303- constexpr auto request = Request::Cost + Request::SoftConstraint + Request::Approximation;
304- ocpDefinition.preComputationPtr ->requestFinal (request, timeDiscretization[N].time , stateTrajectory[N]);
305- const vector_t lmdN = -approximateFinalCost (ocpDefinition, timeDiscretization[N].time , stateTrajectory[N]).dfdx ;
306301
307302 // Determine till when to use the previous solution
308- scalar_t interpolateCostateTill = timeDiscretization.front ().time ;
309- if (primalSolution_.timeTrajectory_ .size () >= 2 ) {
310- interpolateCostateTill = primalSolution_.timeTrajectory_ .back ();
311- }
303+ const auto interpolateTill =
304+ primalSolution_.timeTrajectory_ .size () < 2 ? timeDiscretization.front ().time : primalSolution_.timeTrajectory_ .back ();
312305
313306 const scalar_t initTime = getIntervalStart (timeDiscretization[0 ]);
314- if (initTime < interpolateCostateTill ) {
307+ if (initTime < interpolateTill ) {
315308 costateTrajectory.push_back (LinearInterpolation::interpolate (initTime, primalSolution_.timeTrajectory_ , costateTrajectory_));
316309 } else {
317- costateTrajectory.push_back (lmdN);
318- // costateTrajectory.push_back(vector_t::Zero(stateTrajectory[0].size()));
310+ costateTrajectory.push_back (vector_t::Zero (stateTrajectory[0 ].size ()));
319311 }
320312
321313 for (int i = 0 ; i < N; i++) {
322314 const scalar_t nextTime = getIntervalEnd (timeDiscretization[i + 1 ]);
323- if (nextTime > interpolateCostateTill) { // Initialize with zero
324- // costateTrajectory.push_back(vector_t::Zero(stateTrajectory[i + 1].size()));
325- costateTrajectory.push_back (lmdN);
326- } else { // interpolate previous solution
315+ if (nextTime < interpolateTill) { // interpolate previous solution
327316 costateTrajectory.push_back (LinearInterpolation::interpolate (nextTime, primalSolution_.timeTrajectory_ , costateTrajectory_));
317+ } else { // Initialize with zero
318+ costateTrajectory.push_back (vector_t::Zero (stateTrajectory[i + 1 ].size ()));
328319 }
329320 }
330321}
331322
332323void IpmSolver::initializeProjectionMultiplierTrajectory (const std::vector<AnnotatedTime>& timeDiscretization,
333- vector_array_t & projectionMultiplierTrajectory) {
324+ vector_array_t & projectionMultiplierTrajectory) const {
334325 const size_t N = static_cast <int >(timeDiscretization.size ()) - 1 ; // size of the input trajectory
335326 projectionMultiplierTrajectory.clear ();
336327 projectionMultiplierTrajectory.reserve (N);
337328 const auto & ocpDefinition = ocpDefinitions_[0 ];
338329
339330 // Determine till when to use the previous solution
340- scalar_t interpolateInputTill = timeDiscretization.front ().time ;
341- if (primalSolution_.timeTrajectory_ .size () >= 2 ) {
342- interpolateInputTill = primalSolution_.timeTrajectory_ [primalSolution_.timeTrajectory_ .size () - 2 ];
343- }
331+ const auto interpolateTill =
332+ primalSolution_.timeTrajectory_ .size () < 2 ? timeDiscretization.front ().time : *std::prev (primalSolution_.timeTrajectory_ .end (), 2 );
344333
334+ // @todo Fix this using trajectory spreading
345335 auto interpolateProjectionMultiplierTrajectory = [&](scalar_t time) -> vector_t {
346336 const size_t numConstraints = ocpDefinition.equalityConstraintPtr ->getNumConstraints (time);
347337 const size_t index = LinearInterpolation::timeSegment (time, primalSolution_.timeTrajectory_ ).first ;
@@ -367,10 +357,10 @@ void IpmSolver::initializeProjectionMultiplierTrajectory(const std::vector<Annot
367357 // Intermediate node
368358 const scalar_t time = getIntervalStart (timeDiscretization[i]);
369359 const size_t numConstraints = ocpDefinition.equalityConstraintPtr ->getNumConstraints (time);
370- if (time > interpolateInputTill) { // Initialize with zero
371- projectionMultiplierTrajectory.push_back (vector_t::Zero (numConstraints));
372- } else { // interpolate previous solution
360+ if (time < interpolateTill) { // interpolate previous solution
373361 projectionMultiplierTrajectory.push_back (interpolateProjectionMultiplierTrajectory (time));
362+ } else { // Initialize with zero
363+ projectionMultiplierTrajectory.push_back (vector_t::Zero (numConstraints));
374364 }
375365 }
376366 }
0 commit comments