Skip to content

Commit f44bc84

Browse files
committed
Improve rototranslate and get_body_model APIs
Make rototranslate t default to zero for pure-rotation use case. Remove redundant DATA_ROOT/model_root defaulting from get_body_model (now handled in BodyModel.__init__).
1 parent 9685b8a commit f44bc84

File tree

10 files changed

+54
-33
lines changed

10 files changed

+54
-33
lines changed

src/smplfitter/jax/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import functools
6-
import os
76

87
from .bodymodel import BodyModel
98
from .bodyfitter import BodyFitter
@@ -21,7 +20,4 @@ def get_cached_body_model(model_name='smpl', gender='neutral', model_root=None):
2120

2221

2322
def get_body_model(model_name, gender, model_root=None):
24-
if model_root is None:
25-
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
26-
model_root = f'{DATA_ROOT}/body_models/{model_name}'
2723
return BodyModel(model_root=model_root, gender=gender, model_name=model_name)

src/smplfitter/jax/bodymodel.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,23 @@ def __call__(
156156
)
157157

158158
def rototranslate(
159-
self, R, t, pose_rotvecs, shape_betas, trans, kid_factor=0, post_translate=True
159+
self,
160+
R,
161+
t=None,
162+
pose_rotvecs=None,
163+
shape_betas=None,
164+
trans=None,
165+
kid_factor=0,
166+
post_translate=True,
160167
):
161168
"""Rotate and translate the body in parametric form.
162169
163170
See np.BodyModel.rototranslate for full documentation.
164171
"""
172+
if t is None:
173+
t = jnp.zeros(3, dtype=R.dtype)
174+
if pose_rotvecs is None or shape_betas is None or trans is None:
175+
raise ValueError('pose_rotvecs, shape_betas, and trans are required.')
165176
current_rotmat = rotvec2mat(pose_rotvecs[:3])
166177
new_rotmat = R @ current_rotmat
167178
new_pose_rotvec = jnp.concatenate([mat2rotvec(new_rotmat), pose_rotvecs[3:]], axis=0)

src/smplfitter/nb/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import functools
6-
import os
76

87
from .bodymodel import BodyModel
98
from .bodyfitter import BodyFitter
@@ -21,7 +20,4 @@ def get_cached_body_model(model_name='smpl', gender='neutral', model_root=None):
2120

2221

2322
def get_body_model(model_name, gender, model_root=None):
24-
if model_root is None:
25-
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
26-
model_root = f'{DATA_ROOT}/body_models/{model_name}'
2723
return BodyModel(model_root=model_root, gender=gender, model_name=model_name)

src/smplfitter/nb/bodymodel.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,14 @@ def single(self, *args, return_vertices=True, **kwargs):
223223
return {k: np.squeeze(v, axis=0) for k, v in result.items()}
224224

225225
def rototranslate(
226-
self, R, t, pose_rotvecs, shape_betas, trans, kid_factor=0, post_translate=True
226+
self,
227+
R,
228+
t=None,
229+
pose_rotvecs=None,
230+
shape_betas=None,
231+
trans=None,
232+
kid_factor=0,
233+
post_translate=True,
227234
) -> tuple[np.ndarray, np.ndarray]:
228235
"""
229236
Rotates and translates the body in parametric form.
@@ -239,7 +246,7 @@ def rototranslate(
239246
240247
Parameters:
241248
R: Rotation matrix, shaped as (3, 3).
242-
t: Translation vector, shaped as (3,).
249+
t: Translation vector, shaped as (3,). Defaults to zero (pure rotation).
243250
pose_rotvecs: Initial rotation vectors per joint, shaped as (num_joints * 3,).
244251
shape_betas: Shape coefficients (betas) for body shape, shaped as (num_betas,).
245252
trans: Initial translation vector, shaped as (3,).
@@ -261,6 +268,10 @@ def rototranslate(
261268
account the offset between the pelvis joint in the shaped T-pose and the origin of
262269
the canonical coordinate system.
263270
"""
271+
if t is None:
272+
t = np.zeros(3, dtype=R.dtype)
273+
if pose_rotvecs is None or shape_betas is None or trans is None:
274+
raise ValueError('pose_rotvecs, shape_betas, and trans are required.')
264275
current_rotmat = rotvec2mat(pose_rotvecs[:3])
265276
new_rotmat = R @ current_rotmat
266277
new_pose_rotvec = np.concatenate([mat2rotvec(new_rotmat), pose_rotvecs[3:]], axis=0)

src/smplfitter/np/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from smplfitter.common import _set_module_for_docs
99

1010
import functools
11-
import os
1211

1312
__all__ = ['BodyModel', 'BodyFitter', 'BodyConverter', 'get_cached_body_model']
1413
_set_module_for_docs(__name__, globals(), __all__)
@@ -20,7 +19,4 @@ def get_cached_body_model(model_name='smpl', gender='neutral', model_root=None):
2019

2120

2221
def get_body_model(model_name, gender, model_root=None):
23-
if model_root is None:
24-
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
25-
model_root = f'{DATA_ROOT}/body_models/{model_name}'
2622
return BodyModel(model_root=model_root, gender=gender, model_name=model_name)

src/smplfitter/np/bodymodel.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,14 @@ def single(self, *args, return_vertices=True, **kwargs):
232232
return {k: np.squeeze(v, axis=0) for k, v in result.items()}
233233

234234
def rototranslate(
235-
self, R, t, pose_rotvecs, shape_betas, trans, kid_factor=0, post_translate=True
235+
self,
236+
R,
237+
t=None,
238+
pose_rotvecs=None,
239+
shape_betas=None,
240+
trans=None,
241+
kid_factor=0,
242+
post_translate=True,
236243
) -> tuple[np.ndarray, np.ndarray]:
237244
"""
238245
Rotates and translates the body in parametric form.
@@ -248,7 +255,7 @@ def rototranslate(
248255
249256
Parameters:
250257
R: Rotation matrix, shaped as (3, 3).
251-
t: Translation vector, shaped as (3,).
258+
t: Translation vector, shaped as (3,). Defaults to zero (pure rotation).
252259
pose_rotvecs: Initial rotation vectors per joint, shaped as (num_joints * 3,).
253260
shape_betas: Shape coefficients (betas) for body shape, shaped as (num_betas,).
254261
trans: Initial translation vector, shaped as (3,).
@@ -270,6 +277,10 @@ def rototranslate(
270277
account the offset between the pelvis joint in the shaped T-pose and the origin of
271278
the canonical coordinate system.
272279
"""
280+
if t is None:
281+
t = np.zeros(3, dtype=R.dtype)
282+
if pose_rotvecs is None or shape_betas is None or trans is None:
283+
raise ValueError('pose_rotvecs, shape_betas, and trans are required.')
273284
current_rotmat = rotvec2mat(pose_rotvecs[:3])
274285
new_rotmat = R @ current_rotmat
275286
new_pose_rotvec = np.concatenate([mat2rotvec(new_rotmat), pose_rotvecs[3:]], axis=0)

src/smplfitter/pt/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Optional
66

77
import functools
8-
import os
98
import torch
109
import warnings
1110
from .bodymodel import BodyModel
@@ -33,9 +32,6 @@ def get_cached_body_model(model_name='smpl', gender='neutral', model_root=None):
3332

3433

3534
def get_body_model(model_name, gender, model_root=None):
36-
if model_root is None:
37-
DATA_ROOT = os.getenv('DATA_ROOT', default='.')
38-
model_root = f'{DATA_ROOT}/body_models/{model_name}'
3935
return BodyModel(model_root=model_root, gender=gender, model_name=model_name)
4036

4137

src/smplfitter/pt/bodymodel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,10 @@ def single(
338338
def rototranslate(
339339
self,
340340
R: torch.Tensor,
341-
t: torch.Tensor,
342-
pose_rotvecs: torch.Tensor,
343-
shape_betas: torch.Tensor,
344-
trans: torch.Tensor,
341+
t: Optional[torch.Tensor] = None,
342+
pose_rotvecs: Optional[torch.Tensor] = None,
343+
shape_betas: Optional[torch.Tensor] = None,
344+
trans: Optional[torch.Tensor] = None,
345345
kid_factor: Optional[torch.Tensor] = None,
346346
post_translate: bool = True,
347347
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -387,6 +387,10 @@ def rototranslate(
387387
388388
"""
389389

390+
if t is None:
391+
t = torch.zeros(3, device=R.device, dtype=R.dtype)
392+
if pose_rotvecs is None or shape_betas is None or trans is None:
393+
raise ValueError('pose_rotvecs, shape_betas, and trans are required.')
390394
current_rotmat = rotvec2mat(pose_rotvecs[:3])
391395
new_rotmat = R @ current_rotmat
392396
new_pose_rotvec = torch.cat([mat2rotvec(new_rotmat), pose_rotvecs[3:]], dim=0)

src/smplfitter/tf/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import functools
6-
import os
76

87
import tensorflow as tf
98
from .bodymodel import BodyModel
@@ -27,9 +26,6 @@ def get_cached_body_model(model_name='smpl', gender='neutral', model_root=None):
2726

2827

2928
def get_body_model(model_name, gender, model_root=None, num_betas=None, vertex_subset=None):
30-
if model_root is None:
31-
DATA_ROOT = os.getenv('DATA_ROOT', '.')
32-
model_root = f'{DATA_ROOT}/body_models/{model_name}'
3329
return BodyModel(
3430
model_root=model_root,
3531
gender=gender,

src/smplfitter/tf/bodymodel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,10 @@ def single(self, *args, return_vertices=True, **kwargs):
249249
def rototranslate(
250250
self,
251251
R: tf.Tensor,
252-
t: tf.Tensor,
253-
pose_rotvecs: tf.Tensor,
254-
shape_betas: tf.Tensor,
255-
trans: tf.Tensor,
252+
t=None,
253+
pose_rotvecs=None,
254+
shape_betas=None,
255+
trans=None,
256256
kid_factor=0,
257257
post_translate: bool = True,
258258
):
@@ -291,6 +291,10 @@ def rototranslate(
291291
account the offset between the pelvis joint in the shaped T-pose and the origin of
292292
the canonical coordinate system.
293293
"""
294+
if t is None:
295+
t = tf.zeros(3, dtype=R.dtype)
296+
if pose_rotvecs is None or shape_betas is None or trans is None:
297+
raise ValueError('pose_rotvecs, shape_betas, and trans are required.')
294298
current_rotmat = rotvec2mat(pose_rotvecs[..., :3])
295299
new_rotmat = R @ current_rotmat
296300
new_pose_rotvec = tf.concat([mat2rotvec(new_rotmat), pose_rotvecs[3:]], axis=0)

0 commit comments

Comments
 (0)