Skip to content

Commit 40a8c9e

Browse files
Merged in fix/ddp_init (pull request #644)
Fix/ddp init
2 parents b24c871 + 34634c2 commit 40a8c9e

File tree

8 files changed

+271
-125
lines changed

8 files changed

+271
-125
lines changed

ocs2_ddp/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,13 @@ target_link_libraries(testContinuousTimeLqr
191191
${PROJECT_NAME}
192192
gtest_main
193193
)
194+
195+
catkin_add_gtest(testDdpHelperFunction
196+
test/testDdpHelperFunction.cpp
197+
)
198+
target_link_libraries(testDdpHelperFunction
199+
${Boost_LIBRARIES}
200+
${catkin_LIBRARIES}
201+
${PROJECT_NAME}
202+
gtest_main
203+
)

ocs2_ddp/include/ocs2_ddp/DDP_HelperFunctions.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ PerformanceIndex computeRolloutPerformanceIndex(const scalar_array_t& timeTrajec
8080
scalar_t rolloutTrajectory(RolloutBase& rollout, scalar_t initTime, const vector_t& initState, scalar_t finalTime,
8181
PrimalSolution& primalSolution);
8282

83+
/**
84+
* Extract a primal solution for the range [initTime, finalTime] from a given primal solution. It assumes that the
85+
* given range is within the solution time of input primal solution.
86+
*
87+
* @note: The controller field is ignored.
88+
* @note: The extracted primal solution can have an event time at final time but ignores it at initial time.
89+
*
90+
* @param [in] timePeriod: The time period for which the solution should be extracted.
91+
* @param [in] inputPrimalSolution: The input PrimalSolution
92+
* @param [out] outputPrimalSolution: The output PrimalSolution.
93+
*/
94+
void extractPrimalSolution(const std::pair<scalar_t, scalar_t>& timePeriod, const PrimalSolution& inputPrimalSolution,
95+
PrimalSolution& outputPrimalSolution);
96+
8397
/**
8498
* Computes the integral of the squared (IS) norm of the controller update.
8599
*

ocs2_ddp/include/ocs2_ddp/GaussNewtonDDP.h

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -250,19 +250,15 @@ class GaussNewtonDDP : public SolverBase {
250250
std::vector<std::pair<int, int>> getPartitionIntervalsFromTimeTrajectory(const scalar_array_t& timeTrajectory, int numWorkers);
251251

252252
/**
253-
* Forward integrate the system dynamics with given controller and operating trajectories. In general, it uses the
254-
* given control policies and initial state, to integrate the system dynamics in the time period [initTime, finalTime].
255-
* However, if the provided controller does not cover the period [initTime, finalTime], it extrapolates (zero-order)
256-
* the controller until the next event time where after it uses the operating trajectories.
253+
* Forward integrate the system dynamics with given controller in primalSolution and operating trajectories. In general, it uses
254+
* the given control policies and initial state, to integrate the system dynamics in the time period [initTime, finalTime].
255+
* However, if the provided controller does not cover the period [initTime, finalTime], it will use the controller till the
256+
* final time of the controller and after it uses the operating trajectories.
257257
*
258-
* Attention: Do NOT pass the controllerPtr of the same primalData used for the first parameter to the second parameter, as all
259-
* member variables(including controller) of primal data will be cleared.
260-
*
261-
* @param [out] primalData: primalData
262-
* @param [in] controller: nominal controller used to rollout (time, state, input...) trajectories
263-
* @param [in] workerIndex: working thread (default is 0).
258+
* @param [in, out] primalSolution: The resulting state-input trajectory. The primal solution is initialized with the controller
259+
* and the modeSchedule. However, for StateTriggered Rollout the modeSchedule can be overwritten.
264260
*/
265-
void rolloutInitialTrajectory(PrimalDataContainer& primalData, ControllerBase* controller, size_t workerIndex = 0);
261+
void rolloutInitialTrajectory(PrimalSolution& primalSolution);
266262

267263
/**
268264
* Calculates the controller. This method uses the following variables. The method modifies unoptimizedController_.
@@ -364,9 +360,7 @@ class GaussNewtonDDP : public SolverBase {
364360
std::pair<bool, std::string> checkConvergence(bool isInitalControllerEmpty, const PerformanceIndex& previousPerformanceIndex,
365361
const PerformanceIndex& currentPerformanceIndex) const;
366362

367-
void runImpl(scalar_t initTime, const vector_t& initState, scalar_t finalTime) override {
368-
runImpl(initTime, initState, finalTime, nullptr);
369-
}
363+
void runImpl(scalar_t initTime, const vector_t& initState, scalar_t finalTime) override;
370364

371365
void runImpl(scalar_t initTime, const vector_t& initState, scalar_t finalTime, const ControllerBase* externalControllerPtr) override;
372366

@@ -407,8 +401,8 @@ class GaussNewtonDDP : public SolverBase {
407401
PerformanceIndex performanceIndex_;
408402
std::vector<PerformanceIndex> performanceIndexHistory_;
409403

404+
std::unique_ptr<RolloutBase> initializerRolloutPtr_;
410405
std::vector<std::unique_ptr<RolloutBase>> dynamicsForwardRolloutPtrStock_;
411-
std::vector<std::unique_ptr<RolloutBase>> initializerRolloutPtrStock_;
412406

413407
// optimized data
414408
DualSolution optimizedDualSolution_;

ocs2_ddp/include/ocs2_ddp/SLQ.h

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3232
#include <ocs2_core/integration/Integrator.h>
3333
#include <ocs2_core/integration/SystemEventHandler.h>
3434

35-
#include "GaussNewtonDDP.h"
36-
#include "riccati_equations/ContinuousTimeRiccatiEquations.h"
35+
#include "ocs2_ddp/GaussNewtonDDP.h"
36+
#include "ocs2_ddp/riccati_equations/ContinuousTimeRiccatiEquations.h"
3737

3838
namespace ocs2 {
3939

@@ -89,23 +89,6 @@ class SLQ final : public GaussNewtonDDP {
8989
scalar_array_t& SsNormalizedTime, size_array_t& SsNormalizedPostEventIndices,
9090
vector_array_t& allSsTrajectory);
9191

92-
/**
93-
* Integrates the riccati equation and freely selects the time nodes for the value function.
94-
*
95-
* @param riccatiIntegrator [in] : Riccati integrator object
96-
* @param riccatiEquation [in] : Riccati equation object
97-
* @param nominalTimeTrajectory [in] : time trajectory produced in the forward rollout.
98-
* @param nominalEventsPastTheEndIndices [in] : Indices into nominalTimeTrajectory to point to times right after event times
99-
* @param allSsFinal [in] : Final value of the value function.
100-
* @param SsNormalizedTime [out] : Time trajectory of the value function.
101-
* @param SsNormalizedPostEventIndices [out] : Indices into SsNormalizedTime to point to times right after event times
102-
* @param allSsTrajectory [out] : Value function in vector format.
103-
*/
104-
void integrateRiccatiEquationAdaptiveTime(IntegratorBase& riccatiIntegrator, ContinuousTimeRiccatiEquations& riccatiEquation,
105-
const scalar_array_t& nominalTimeTrajectory, const size_array_t& nominalEventsPastTheEndIndices,
106-
vector_t allSsFinal, scalar_array_t& SsNormalizedTime,
107-
size_array_t& SsNormalizedPostEventIndices, vector_array_t& allSsTrajectory);
108-
10992
/****************
11093
*** Variables **
11194
****************/

ocs2_ddp/src/DDP_HelperFunctions.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,29 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3434

3535
#include <ocs2_core/PreComputation.h>
3636
#include <ocs2_core/integration/TrapezoidalIntegration.h>
37+
#include <ocs2_core/misc/LinearInterpolation.h>
3738
#include <ocs2_oc/approximate_model/LinearQuadraticApproximator.h>
3839

3940
namespace ocs2 {
4041

42+
namespace {
43+
template <typename DataType>
44+
void copySegment(const LinearInterpolation::index_alpha_t& indexAlpha0, const LinearInterpolation::index_alpha_t& indexAlpha1,
45+
const std::vector<DataType>& inputTrajectory, std::vector<DataType>& outputTrajectory) {
46+
outputTrajectory.clear();
47+
outputTrajectory.resize(2 + indexAlpha1.first - indexAlpha0.first);
48+
49+
if (!outputTrajectory.empty()) {
50+
outputTrajectory.front() = LinearInterpolation::interpolate(indexAlpha0, inputTrajectory);
51+
if (indexAlpha1.first >= indexAlpha0.first) {
52+
std::copy(inputTrajectory.begin() + indexAlpha0.first + 1, inputTrajectory.begin() + indexAlpha1.first + 1,
53+
outputTrajectory.begin() + 1);
54+
}
55+
outputTrajectory.back() = LinearInterpolation::interpolate(indexAlpha1, inputTrajectory);
56+
}
57+
}
58+
} // unnamed namespace
59+
4160
/******************************************************************************************************/
4261
/******************************************************************************************************/
4362
/******************************************************************************************************/
@@ -165,6 +184,74 @@ scalar_t rolloutTrajectory(RolloutBase& rollout, scalar_t initTime, const vector
165184
return (finalTime - initTime) / static_cast<scalar_t>(primalSolution.timeTrajectory_.size());
166185
}
167186

187+
/******************************************************************************************************/
188+
/******************************************************************************************************/
189+
/******************************************************************************************************/
190+
void extractPrimalSolution(const std::pair<scalar_t, scalar_t>& timePeriod, const PrimalSolution& inputPrimalSolution,
191+
PrimalSolution& outputPrimalSolution) {
192+
// no controller
193+
if (outputPrimalSolution.controllerPtr_ != nullptr) {
194+
outputPrimalSolution.controllerPtr_->clear();
195+
}
196+
// for none StateTriggeredRollout initialize modeSchedule
197+
outputPrimalSolution.modeSchedule_ = inputPrimalSolution.modeSchedule_;
198+
199+
// create alias
200+
auto& timeTrajectory = outputPrimalSolution.timeTrajectory_;
201+
auto& stateTrajectory = outputPrimalSolution.stateTrajectory_;
202+
auto& inputTrajectory = outputPrimalSolution.inputTrajectory_;
203+
auto& postEventIndices = outputPrimalSolution.postEventIndices_;
204+
205+
/*
206+
* Find the indexAlpha pair for interpolation. The interpolation function uses the std::lower_bound while ignoring the initial
207+
* time event, we should use std::upper_bound. Therefore at the first step, we check if for the case where std::upper_bound
208+
* would have give a different solution (index_alpha_t::second = 0) and correct the pair. Then, in the second step, we check
209+
* whether the index_alpha_t::first is a pre-event index. If yes, we move index_alpha_t::first to the post-event index.
210+
*/
211+
const auto indexAlpha0 = [&]() {
212+
const auto lowerBoundIndexAlpha = LinearInterpolation::timeSegment(timePeriod.first, inputPrimalSolution.timeTrajectory_);
213+
214+
const auto upperBoundIndexAlpha = numerics::almost_eq(lowerBoundIndexAlpha.second, 0.0)
215+
? LinearInterpolation::index_alpha_t{lowerBoundIndexAlpha.first + 1, 1.0}
216+
: lowerBoundIndexAlpha;
217+
const auto it = std::find(inputPrimalSolution.postEventIndices_.cbegin(), inputPrimalSolution.postEventIndices_.cend(),
218+
upperBoundIndexAlpha.first + 1);
219+
if (it == inputPrimalSolution.postEventIndices_.cend()) {
220+
return upperBoundIndexAlpha;
221+
} else {
222+
return LinearInterpolation::index_alpha_t{upperBoundIndexAlpha.first + 1, 1.0};
223+
}
224+
}();
225+
const auto indexAlpha1 = LinearInterpolation::timeSegment(timePeriod.second, inputPrimalSolution.timeTrajectory_);
226+
227+
// time
228+
copySegment(indexAlpha0, indexAlpha1, inputPrimalSolution.timeTrajectory_, timeTrajectory);
229+
230+
// state
231+
copySegment(indexAlpha0, indexAlpha1, inputPrimalSolution.stateTrajectory_, stateTrajectory);
232+
233+
// input
234+
copySegment(indexAlpha0, indexAlpha1, inputPrimalSolution.inputTrajectory_, inputTrajectory);
235+
236+
// If the pre-event index is within the range we accept the event
237+
postEventIndices.clear();
238+
for (const auto& postIndex : inputPrimalSolution.postEventIndices_) {
239+
if (postIndex > static_cast<size_t>(indexAlpha0.first) && inputPrimalSolution.timeTrajectory_[postIndex - 1] <= timePeriod.second) {
240+
postEventIndices.push_back(postIndex - static_cast<size_t>(indexAlpha0.first));
241+
}
242+
}
243+
244+
// If there is an event at final time, it misses its pair (due to indexAlpha1 and copySegment)
245+
if (!postEventIndices.empty() && postEventIndices.back() == timeTrajectory.size()) {
246+
constexpr auto eps = numeric_traits::weakEpsilon<scalar_t>();
247+
const auto indexAlpha2 = LinearInterpolation::timeSegment(timePeriod.second + eps, inputPrimalSolution.timeTrajectory_);
248+
249+
timeTrajectory.push_back(std::min(timePeriod.second + eps, timePeriod.second));
250+
stateTrajectory.push_back(LinearInterpolation::interpolate(indexAlpha2, inputPrimalSolution.stateTrajectory_));
251+
inputTrajectory.push_back(LinearInterpolation::interpolate(indexAlpha2, inputPrimalSolution.inputTrajectory_));
252+
}
253+
}
254+
168255
/******************************************************************************************************/
169256
/******************************************************************************************************/
170257
/******************************************************************************************************/

0 commit comments

Comments
 (0)