Skip to content

Commit 5a3a956

Browse files
committed
forward the precomputation to the getParameter functions
1 parent 0d60cd6 commit 5a3a956

12 files changed

+63
-50
lines changed

ocs2_core/include/ocs2_core/constraint/StateConstraintCppAd.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ class StateConstraintCppAd : public StateConstraint {
5656
bool recompileLibraries = true, bool verbose = true);
5757

5858
/** Get the parameter vector */
59-
virtual vector_t getParameters(scalar_t time) const { return vector_t(0); };
59+
virtual vector_t getParameters(scalar_t time, const PreComputation& /* preComputation */) const { return vector_t(0); };
6060

6161
/** Constraint evaluation */
62-
vector_t getValue(scalar_t time, const vector_t& state, const PreComputation& /* preComputation */) const override;
62+
vector_t getValue(scalar_t time, const vector_t& state, const PreComputation& preComputation) const override;
6363
VectorFunctionLinearApproximation getLinearApproximation(scalar_t time, const vector_t& state,
64-
const PreComputation& /* preComputation */) const override;
64+
const PreComputation& preComputation) const override;
6565
VectorFunctionQuadraticApproximation getQuadraticApproximation(scalar_t time, const vector_t& state,
66-
const PreComputation& /* preComputation */) const override;
66+
const PreComputation& preComputation) const override;
6767

6868
protected:
6969
StateConstraintCppAd(const StateConstraintCppAd& rhs);

ocs2_core/include/ocs2_core/constraint/StateInputConstraintCppAd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class StateInputConstraintCppAd : public StateInputConstraint {
5757
const std::string& modelFolder = "/tmp/ocs2", bool recompileLibraries = true, bool verbose = true);
5858

5959
/** Get the parameter vector */
60-
virtual vector_t getParameters(scalar_t time) const { return vector_t(0); };
60+
virtual vector_t getParameters(scalar_t time, const PreComputation& /* preComputation */) const { return vector_t(0); };
6161

6262
/** Constraint evaluation */
6363
vector_t getValue(scalar_t time, const vector_t& state, const vector_t& input, const PreComputation& /* preComputation */) const override;

ocs2_core/include/ocs2_core/cost/StateCostCppAd.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ class StateCostCppAd : public StateCost {
5656
bool recompileLibraries = true, bool verbose = true);
5757

5858
/* Get the parameter vector */
59-
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories) const { return vector_t(0); };
59+
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories,
60+
const PreComputation& /* preComputation */) const {
61+
return vector_t(0);
62+
};
6063

6164
/* Cost evaluation */
6265
scalar_t getValue(scalar_t time, const vector_t& state, const TargetTrajectories& targetTrajectories,

ocs2_core/include/ocs2_core/cost/StateInputCostCppAd.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,17 @@ class StateInputCostCppAd : public StateInputCost {
5757
const std::string& modelFolder = "/tmp/ocs2", bool recompileLibraries = true, bool verbose = true);
5858

5959
/** Get the parameter vector */
60-
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories) const { return vector_t(0); };
60+
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories,
61+
const PreComputation& /* preComputation */) const {
62+
return vector_t(0);
63+
};
6164

6265
/** Cost evaluation */
6366
scalar_t getValue(scalar_t time, const vector_t& state, const vector_t& input, const TargetTrajectories& targetTrajectories,
64-
const PreComputation&) const override;
67+
const PreComputation& preComputation) const override;
6568
ScalarFunctionQuadraticApproximation getQuadraticApproximation(scalar_t time, const vector_t& state, const vector_t& input,
6669
const TargetTrajectories& targetTrajectories,
67-
const PreComputation&) const override;
70+
const PreComputation& preComputation) const override;
6871

6972
protected:
7073
StateInputCostCppAd(const StateInputCostCppAd& rhs);

ocs2_core/include/ocs2_core/cost/StateInputGaussNewtonCostAd.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,17 @@ class StateInputCostGaussNewtonAd : public StateInputCost {
6565
const std::string& modelFolder = "/tmp/ocs2", bool recompileLibraries = true, bool verbose = true);
6666

6767
/** Get the parameter vector */
68-
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories) const { return vector_t(0); };
68+
virtual vector_t getParameters(scalar_t time, const TargetTrajectories& targetTrajectories,
69+
const PreComputation& /* preComputation */) const {
70+
return vector_t(0);
71+
};
6972

7073
/** Cost evaluation */
7174
scalar_t getValue(scalar_t time, const vector_t& state, const vector_t& input, const TargetTrajectories& targetTrajectories,
72-
const PreComputation&) const override;
75+
const PreComputation& preComputation) const override;
7376
ScalarFunctionQuadraticApproximation getQuadraticApproximation(scalar_t time, const vector_t& state, const vector_t& input,
7477
const TargetTrajectories& targetTrajectories,
75-
const PreComputation&) const override;
78+
const PreComputation& preComputation) const override;
7679

7780
protected:
7881
StateInputCostGaussNewtonAd(const StateInputCostGaussNewtonAd& rhs);

ocs2_core/include/ocs2_core/dynamics/SystemDynamicsBaseAD.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,16 @@ class SystemDynamicsBaseAD : public SystemDynamicsBase {
6363
void initialize(size_t stateDim, size_t inputDim, const std::string& modelName, const std::string& modelFolder = "/tmp/ocs2",
6464
bool recompileLibraries = true, bool verbose = true);
6565

66-
vector_t computeFlowMap(scalar_t t, const vector_t& x, const vector_t& u, const PreComputation&) final;
66+
vector_t computeFlowMap(scalar_t t, const vector_t& x, const vector_t& u, const PreComputation& preComputation) final;
6767

68-
vector_t computeJumpMap(scalar_t t, const vector_t& x, const PreComputation&) final;
68+
vector_t computeJumpMap(scalar_t t, const vector_t& x, const PreComputation& preComputation) final;
6969

7070
vector_t computeGuardSurfaces(scalar_t t, const vector_t& x) final;
7171

72-
VectorFunctionLinearApproximation linearApproximation(scalar_t t, const vector_t& x, const vector_t& u, const PreComputation&) final;
72+
VectorFunctionLinearApproximation linearApproximation(scalar_t t, const vector_t& x, const vector_t& u,
73+
const PreComputation& preComputation) final;
7374

74-
VectorFunctionLinearApproximation jumpMapLinearApproximation(scalar_t t, const vector_t& x, const PreComputation&) final;
75+
VectorFunctionLinearApproximation jumpMapLinearApproximation(scalar_t t, const vector_t& x, const PreComputation& preComputation) final;
7576

7677
VectorFunctionLinearApproximation guardSurfacesLinearApproximation(scalar_t t, const vector_t& x, const vector_t& u) final;
7778

@@ -106,7 +107,7 @@ class SystemDynamicsBaseAD : public SystemDynamicsBase {
106107
* @param [in] time: Current time.
107108
* @return The parameters to be set in the flow map at the start of the horizon
108109
*/
109-
virtual vector_t getFlowMapParameters(scalar_t time) const { return vector_t(0); }
110+
virtual vector_t getFlowMapParameters(scalar_t time, const PreComputation& /* preComputation */) const { return vector_t(0); }
110111

111112
/**
112113
* Number of parameters for system flow map.
@@ -131,7 +132,7 @@ class SystemDynamicsBaseAD : public SystemDynamicsBase {
131132
* @param [in] time: Current time.
132133
* @return The parameters to be set in the jump map
133134
*/
134-
virtual vector_t getJumpMapParameters(scalar_t time) const { return vector_t(0); }
135+
virtual vector_t getJumpMapParameters(scalar_t time, const PreComputation& /* preComputation */) const { return vector_t(0); }
135136

136137
/**
137138
* Number of parameters for jump map.

ocs2_core/src/constraint/StateConstraintCppAd.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,21 @@ StateConstraintCppAd::StateConstraintCppAd(const StateConstraintCppAd& rhs)
6767
/******************************************************************************************************/
6868
/******************************************************************************************************/
6969
/******************************************************************************************************/
70-
vector_t StateConstraintCppAd::getValue(scalar_t time, const vector_t& state, const PreComputation&) const {
70+
vector_t StateConstraintCppAd::getValue(scalar_t time, const vector_t& state, const PreComputation& preComputation) const {
7171
vector_t tapedTimeState(1 + state.rows());
7272
tapedTimeState << time, state;
73-
return adInterfacePtr_->getFunctionValue(tapedTimeState, getParameters(time));
73+
return adInterfacePtr_->getFunctionValue(tapedTimeState, getParameters(time, preComputation));
7474
}
7575

7676
/******************************************************************************************************/
7777
/******************************************************************************************************/
7878
/******************************************************************************************************/
7979
VectorFunctionLinearApproximation StateConstraintCppAd::getLinearApproximation(scalar_t time, const vector_t& state,
80-
const PreComputation&) const {
80+
const PreComputation& preComputation) const {
8181
VectorFunctionLinearApproximation constraint;
8282

8383
const size_t stateDim = state.rows();
84-
const vector_t params = getParameters(time);
84+
const vector_t params = getParameters(time, preComputation);
8585
vector_t tapedTimeState(1 + stateDim);
8686
tapedTimeState << time, state;
8787

@@ -96,15 +96,15 @@ VectorFunctionLinearApproximation StateConstraintCppAd::getLinearApproximation(s
9696
/******************************************************************************************************/
9797
/******************************************************************************************************/
9898
VectorFunctionQuadraticApproximation StateConstraintCppAd::getQuadraticApproximation(scalar_t time, const vector_t& state,
99-
const PreComputation&) const {
99+
const PreComputation& preComputation) const {
100100
if (getOrder() != ConstraintOrder::Quadratic) {
101101
throw std::runtime_error("[StateConstraintCppAd] Quadratic approximation not supported!");
102102
}
103103

104104
VectorFunctionQuadraticApproximation constraint;
105105

106106
const size_t stateDim = state.rows();
107-
const vector_t params = getParameters(time);
107+
const vector_t params = getParameters(time, preComputation);
108108
vector_t tapedTimeState(1 + stateDim);
109109
tapedTimeState << time, state;
110110

ocs2_core/src/constraint/StateInputConstraintCppAd.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,24 @@ StateInputConstraintCppAd::StateInputConstraintCppAd(const StateInputConstraintC
6868
/******************************************************************************************************/
6969
/******************************************************************************************************/
7070
/******************************************************************************************************/
71-
vector_t StateInputConstraintCppAd::getValue(scalar_t time, const vector_t& state, const vector_t& input, const PreComputation&) const {
71+
vector_t StateInputConstraintCppAd::getValue(scalar_t time, const vector_t& state, const vector_t& input,
72+
const PreComputation& preComputation) const {
7273
vector_t tapedTimeStateInput(1 + state.rows() + input.rows());
7374
tapedTimeStateInput << time, state, input;
74-
return adInterfacePtr_->getFunctionValue(tapedTimeStateInput, getParameters(time));
75+
return adInterfacePtr_->getFunctionValue(tapedTimeStateInput, getParameters(time, preComputation));
7576
}
7677

7778
/******************************************************************************************************/
7879
/******************************************************************************************************/
7980
/******************************************************************************************************/
8081
VectorFunctionLinearApproximation StateInputConstraintCppAd::getLinearApproximation(scalar_t time, const vector_t& state,
81-
const vector_t& input, const PreComputation&) const {
82+
const vector_t& input,
83+
const PreComputation& preComputation) const {
8284
VectorFunctionLinearApproximation constraint;
8385

8486
const size_t stateDim = state.rows();
8587
const size_t inputDim = input.rows();
86-
const vector_t params = getParameters(time);
88+
const vector_t params = getParameters(time, preComputation);
8789
vector_t tapedTimeStateInput(1 + stateDim + inputDim);
8890
tapedTimeStateInput << time, state, input;
8991

@@ -100,7 +102,7 @@ VectorFunctionLinearApproximation StateInputConstraintCppAd::getLinearApproximat
100102
/******************************************************************************************************/
101103
VectorFunctionQuadraticApproximation StateInputConstraintCppAd::getQuadraticApproximation(scalar_t time, const vector_t& state,
102104
const vector_t& input,
103-
const PreComputation&) const {
105+
const PreComputation& preComputation) const {
104106
if (getOrder() != ConstraintOrder::Quadratic) {
105107
throw std::runtime_error("[StateInputConstraintCppAd] Quadratic approximation not supported!");
106108
}
@@ -109,7 +111,7 @@ VectorFunctionQuadraticApproximation StateInputConstraintCppAd::getQuadraticAppr
109111

110112
const size_t stateDim = state.rows();
111113
const size_t inputDim = input.rows();
112-
const vector_t params = getParameters(time);
114+
const vector_t params = getParameters(time, preComputation);
113115
vector_t tapedTimeStateInput(1 + stateDim + inputDim);
114116
tapedTimeStateInput << time, state, input;
115117

ocs2_core/src/cost/StateCostCppAd.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,22 @@ StateCostCppAd::StateCostCppAd(const StateCostCppAd& rhs)
6262
/******************************************************************************************************/
6363
/******************************************************************************************************/
6464
scalar_t StateCostCppAd::getValue(scalar_t time, const vector_t& state, const TargetTrajectories& targetTrajectories,
65-
const PreComputation&) const {
65+
const PreComputation& preComputation) const {
6666
vector_t tapedTimeState(1 + state.rows());
6767
tapedTimeState << time, state;
68-
return adInterfacePtr_->getFunctionValue(tapedTimeState, getParameters(time, targetTrajectories))(0);
68+
return adInterfacePtr_->getFunctionValue(tapedTimeState, getParameters(time, targetTrajectories, preComputation))(0);
6969
}
7070

7171
/******************************************************************************************************/
7272
/******************************************************************************************************/
7373
/******************************************************************************************************/
7474
ScalarFunctionQuadraticApproximation StateCostCppAd::getQuadraticApproximation(scalar_t time, const vector_t& state,
7575
const TargetTrajectories& targetTrajectories,
76-
const PreComputation&) const {
76+
const PreComputation& preComputation) const {
7777
ScalarFunctionQuadraticApproximation cost;
7878

7979
const size_t stateDim = state.rows();
80-
const vector_t params = getParameters(time, targetTrajectories);
80+
const vector_t params = getParameters(time, targetTrajectories, preComputation);
8181
vector_t tapedTimeState(1 + stateDim);
8282
tapedTimeState << time, state;
8383

ocs2_core/src/cost/StateInputCostCppAd.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ StateInputCostCppAd::StateInputCostCppAd(const StateInputCostCppAd& rhs)
6363
/******************************************************************************************************/
6464
/******************************************************************************************************/
6565
scalar_t StateInputCostCppAd::getValue(scalar_t time, const vector_t& state, const vector_t& input,
66-
const TargetTrajectories& targetTrajectories, const PreComputation&) const {
66+
const TargetTrajectories& targetTrajectories, const PreComputation& preComputation) const {
6767
vector_t tapedTimeStateInput(1 + state.rows() + input.rows());
6868
tapedTimeStateInput << time, state, input;
69-
return adInterfacePtr_->getFunctionValue(tapedTimeStateInput, getParameters(time, targetTrajectories))(0);
69+
return adInterfacePtr_->getFunctionValue(tapedTimeStateInput, getParameters(time, targetTrajectories, preComputation))(0);
7070
}
7171

7272
/******************************************************************************************************/
@@ -75,12 +75,12 @@ scalar_t StateInputCostCppAd::getValue(scalar_t time, const vector_t& state, con
7575
ScalarFunctionQuadraticApproximation StateInputCostCppAd::getQuadraticApproximation(scalar_t time, const vector_t& state,
7676
const vector_t& input,
7777
const TargetTrajectories& targetTrajectories,
78-
const PreComputation&) const {
78+
const PreComputation& preComputation) const {
7979
ScalarFunctionQuadraticApproximation cost;
8080

8181
const size_t stateDim = state.rows();
8282
const size_t inputDim = input.rows();
83-
const vector_t params = getParameters(time, targetTrajectories);
83+
const vector_t params = getParameters(time, targetTrajectories, preComputation);
8484
vector_t tapedTimeStateInput(1 + stateDim + inputDim);
8585
tapedTimeStateInput << time, state, input;
8686

0 commit comments

Comments
 (0)