@@ -164,9 +164,15 @@ static inline const mjtNum* mj_stateElemConstPtr(const mjModel* m, const mjData*
164164
165165
166166// get size of state signature
167- int mj_stateSize (const mjModel * m , unsigned int sig ) {
167+ int mj_stateSize (const mjModel * m , int sig ) {
168+ if (sig < 0 ) {
169+ mjERROR ("invalid state signature %d < 0" , sig );
170+ return 0 ;
171+ }
172+
168173 if (sig >= (1 <<mjNSTATE )) {
169- mjERROR ("invalid state signature %u >= 2^mjNSTATE" , sig );
174+ mjERROR ("invalid state signature %d >= 2^mjNSTATE" , sig );
175+ return 0 ;
170176 }
171177
172178 int size = 0 ;
@@ -182,9 +188,15 @@ int mj_stateSize(const mjModel* m, unsigned int sig) {
182188
183189
184190// get state
185- void mj_getState (const mjModel * m , const mjData * d , mjtNum * state , unsigned int sig ) {
191+ void mj_getState (const mjModel * m , const mjData * d , mjtNum * state , int sig ) {
192+ if (sig < 0 ) {
193+ mjERROR ("invalid state signature %d < 0" , sig );
194+ return ;
195+ }
196+
186197 if (sig >= (1 <<mjNSTATE )) {
187- mjERROR ("invalid state signature %u >= 2^mjNSTATE" , sig );
198+ mjERROR ("invalid state signature %d >= 2^mjNSTATE" , sig );
199+ return ;
188200 }
189201
190202 int adr = 0 ;
@@ -213,8 +225,17 @@ void mj_getState(const mjModel* m, const mjData* d, mjtNum* state, unsigned int
213225
214226
215227// extract a sub-state from a state
216- void mj_extractState (const mjModel * m , const mjtNum * src , unsigned int srcsig ,
217- mjtNum * dst , unsigned int dstsig ) {
228+ void mj_extractState (const mjModel * m , const mjtNum * src , int srcsig , mjtNum * dst , int dstsig ) {
229+ if (srcsig < 0 ) {
230+ mjERROR ("invalid srcsig %d < 0" , srcsig );
231+ return ;
232+ }
233+
234+ if (srcsig >= (1 <<mjNSTATE )) {
235+ mjERROR ("invalid srcsig %d >= 2^mjNSTATE" , srcsig );
236+ return ;
237+ }
238+
218239 if ((srcsig & dstsig ) != dstsig ) {
219240 mjERROR ("dstsig is not a subset of srcsig" );
220241 return ;
@@ -235,9 +256,15 @@ void mj_extractState(const mjModel* m, const mjtNum* src, unsigned int srcsig,
235256
236257
237258// set state
238- void mj_setState (const mjModel * m , mjData * d , const mjtNum * state , unsigned int sig ) {
259+ void mj_setState (const mjModel * m , mjData * d , const mjtNum * state , int sig ) {
260+ if (sig < 0 ) {
261+ mjERROR ("invalid state signature %d < 0" , sig );
262+ return ;
263+ }
264+
239265 if (sig >= (1 <<mjNSTATE )) {
240- mjERROR ("invalid state signature %u >= 2^mjNSTATE" , sig );
266+ mjERROR ("invalid state signature %d >= 2^mjNSTATE" , sig );
267+ return ;
241268 }
242269
243270 int adr = 0 ;
@@ -266,9 +293,15 @@ void mj_setState(const mjModel* m, mjData* d, const mjtNum* state, unsigned int
266293
267294
268295// copy state from src to dst
269- void mj_copyState (const mjModel * m , const mjData * src , mjData * dst , unsigned int sig ) {
296+ void mj_copyState (const mjModel * m , const mjData * src , mjData * dst , int sig ) {
297+ if (sig < 0 ) {
298+ mjERROR ("invalid state signature %d < 0" , sig );
299+ return ;
300+ }
301+
270302 if (sig >= (1 <<mjNSTATE )) {
271- mjERROR ("invalid state signature %u >= 2^mjNSTATE" , sig );
303+ mjERROR ("invalid state signature %d >= 2^mjNSTATE" , sig );
304+ return ;
272305 }
273306
274307 for (int i = 0 ; i < mjNSTATE ; i ++ ) {
0 commit comments