Status: Not yet implemented — nice-to-have for documentation & runtime shape checking.
jaxtyping provides shape-annotated tensor types that work with PyTorch. Combined with beartype, it validates shapes at runtime (e.g. during tests).
- Self-documenting shapes: `Float[Tensor, "batch obs_dim"]` beats `torch.Tensor` plus a comment.
- Catches shape bugs at call boundaries: mismatched dimensions fail immediately with a clear error.
- Named dimensions: `"batch"` must match across all arguments and the return value, enforcing consistency.
- Zero production overhead: annotations alone are never checked; checking is only active in functions decorated with `@jaxtyped`.
```python
from jaxtyping import Float
from torch import Tensor

# Shapes are part of the signature; no comments needed
def forward(
    self, obs: Float[Tensor, "batch obs_dim"]
) -> tuple[Float[Tensor, "batch action_dim"], Float[Tensor, " batch"]]:
    ...
```

Priority files (core data flow, most shape-sensitive):
- `roboro/core/types.py`: `Batch` fields (`obs`, `actions`, `rewards`, etc.)
- `roboro/actors/*.py`: `act()` and `forward()` signatures
- `roboro/critics/*.py`: `forward()` signatures (discrete vs continuous shapes)
- `roboro/updates/*.py`: local variables in `update()` for clarity
- `roboro/nn/blocks.py`: `MLPBlock.forward()`
Lower priority (less shape-sensitive):
- `roboro/data/replay_buffer.py`
- `roboro/encoders/*.py`
- `roboro/training/trainer.py`
```bash
pip install jaxtyping beartype
```

To enable runtime shape validation during testing, decorate functions:
```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def forward(self, obs: Float[Tensor, "batch obs_dim"]) -> Float[Tensor, "batch n_actions"]:
    ...
```

Or apply it globally in `conftest.py` for test runs only.
- `from __future__ import annotations` was removed from the codebase (Python 3.12 doesn't need it), which unblocks jaxtyping's runtime inspection.
- jaxtyping works purely as documentation even without beartype: shapes show up in IDE hover and function signatures.
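- The leading space in `" batch"` is a linter workaround, not separate syntax: a one-word string such as `"batch"` looks like a forward reference to an undefined name, which flake8/Ruff report as F821. Both `" batch"` and `"batch"` denote a single named dimension; `"batch dim"` denotes two.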