@@ -48,40 +48,119 @@ void ModelDerivatives::Compute(const mjModel* m,
4848 const double * h, int dim_state,
4949 int dim_state_derivative, int dim_action,
5050 int dim_sensor, int T, double tol, int mode,
51- ThreadPool& pool) {
52- {
53- int count_before = pool.GetCount ();
54- for (int t = 0 ; t < T; t++) {
55- pool.Schedule ([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h,
56- dim_state, dim_state_derivative, dim_action, dim_sensor,
57- tol, mode, t, T]() {
58- mjData* d = data[ThreadPool::WorkerId ()].get ();
59- // set state
60- SetState (m, d, x + t * dim_state);
61- d->time = h[t];
62-
63- // set action
64- mju_copy (d->ctrl , u + t * dim_action, dim_action);
51+ ThreadPool& pool, int skip) {
52+ // reset indices
53+ evaluate_.clear ();
54+ interpolate_.clear ();
6555
66- // Jacobians
67- if (t == T - 1 ) {
68- // Jacobians
69- mjd_transitionFD (m, d, tol, mode, nullptr , nullptr ,
70- DataAt (C, t * (dim_sensor * dim_state_derivative)),
71- nullptr );
72- } else {
73- // derivatives
74- mjd_transitionFD (
75- m, d, tol, mode,
76- DataAt (A, t * (dim_state_derivative * dim_state_derivative)),
77- DataAt (B, t * (dim_state_derivative * dim_action)),
78- DataAt (C, t * (dim_sensor * dim_state_derivative)),
79- DataAt (D, t * (dim_sensor * dim_action)));
80- }
81- });
56+ // evaluate indices
57+ int s = skip + 1 ;
58+ evaluate_.push_back (0 );
59+ for (int t = s; t < T - s; t += s) {
60+ evaluate_.push_back (t);
61+ }
62+ evaluate_.push_back (T - 2 );
63+ evaluate_.push_back (T - 1 );
64+
65+ // interpolate indices
66+ for (int t = 0 , e = 0 ; t < T; t++) {
67+ if (e == evaluate_.size () || evaluate_[e] > t) {
68+ interpolate_.push_back (t);
69+ } else {
70+ e++;
8271 }
83- pool.WaitCount (count_before + T);
8472 }
73+
74+ // evaluate derivatives
75+ int count_before = pool.GetCount ();
76+ for (int t : evaluate_) {
77+ pool.Schedule ([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h,
78+ dim_state, dim_state_derivative, dim_action, dim_sensor, tol,
79+ mode, t, T]() {
80+ mjData* d = data[ThreadPool::WorkerId ()].get ();
81+ // set state
82+ SetState (m, d, x + t * dim_state);
83+ d->time = h[t];
84+
85+ // set action
86+ mju_copy (d->ctrl , u + t * dim_action, dim_action);
87+
88+ // Jacobians
89+ if (t == T - 1 ) {
90+ // Jacobians
91+ mjd_transitionFD (m, d, tol, mode, nullptr , nullptr ,
92+ DataAt (C, t * (dim_sensor * dim_state_derivative)),
93+ nullptr );
94+ } else {
95+ // derivatives
96+ mjd_transitionFD (
97+ m, d, tol, mode,
98+ DataAt (A, t * (dim_state_derivative * dim_state_derivative)),
99+ DataAt (B, t * (dim_state_derivative * dim_action)),
100+ DataAt (C, t * (dim_sensor * dim_state_derivative)),
101+ DataAt (D, t * (dim_sensor * dim_action)));
102+ }
103+ });
104+ }
105+ pool.WaitCount (count_before + evaluate_.size ());
106+ pool.ResetCount ();
107+
108+ // interpolate derivatives
109+ count_before = pool.GetCount ();
110+ for (int t : interpolate_) {
111+ pool.Schedule ([&A = A, &B = B, &C = C, &D = D, &evaluate_ = this ->evaluate_ ,
112+ dim_state_derivative, dim_action, dim_sensor, t]() {
113+ // find interval
114+ int bounds[2 ];
115+ FindInterval (bounds, evaluate_, t, evaluate_.size ());
116+ int e0 = evaluate_[bounds[0 ]];
117+ int e1 = evaluate_[bounds[1 ]];
118+
119+ // normalized input
120+ double tt = double (t - e0 ) / double (e1 - e0 );
121+ if (bounds[0 ] == bounds[1 ]) {
122+ tt = 0.0 ;
123+ }
124+
125+ // A
126+ int nA = dim_state_derivative * dim_state_derivative;
127+ double * Ai = DataAt (A, t * nA);
128+ const double * AL = DataAt (A, e0 * nA);
129+ const double * AU = DataAt (A, e1 * nA);
130+
131+ mju_scl (Ai, AL, 1.0 - tt, nA);
132+ mju_addToScl (Ai, AU, tt, nA);
133+
134+ // B
135+ int nB = dim_state_derivative * dim_action;
136+ double * Bi = DataAt (B, t * nB);
137+ const double * BL = DataAt (B, e0 * nB);
138+ const double * BU = DataAt (B, e1 * nB);
139+
140+ mju_scl (Bi, BL, 1.0 - tt, nB);
141+ mju_addToScl (Bi, BU, tt, nB);
142+
143+ // C
144+ int nC = dim_sensor * dim_state_derivative;
145+ double * Ci = DataAt (C, t * nC);
146+ const double * CL = DataAt (C, e0 * nC);
147+ const double * CU = DataAt (C, e1 * nC);
148+
149+ mju_scl (Ci, CL, 1.0 - tt, nC);
150+ mju_addToScl (Ci, CU, tt, nC);
151+
152+ // D
153+ int nD = dim_sensor * dim_action;
154+ double * Di = DataAt (D, t * nD);
155+ const double * DL = DataAt (D, e0 * nD);
156+ const double * DU = DataAt (D, e1 * nD);
157+
158+ mju_scl (Di, DL, 1.0 - tt, nD);
159+ mju_addToScl (Di, DU, tt, nD);
160+ });
161+ }
162+
163+ pool.WaitCount (count_before + interpolate_.size ());
85164 pool.ResetCount ();
86165}
87166
0 commit comments