
Development Notes

Optional: Tensor Shape Annotations with jaxtyping + beartype

Status: Not yet implemented — nice-to-have for documentation & runtime shape checking.

What

jaxtyping provides shape-annotated tensor types that work with PyTorch. Combined with beartype, it validates shapes at runtime (e.g. during tests).

Why

  • Self-documenting shapes: Float[Tensor, "batch obs_dim"] beats torch.Tensor + a comment
  • Catches shape bugs at call boundaries: when runtime checking is enabled, mismatched dimensions fail immediately with a clear error
  • Named dimensions: "batch" must match across all arguments (and the return value), enforcing consistency
  • Zero production overhead: annotations are only checked at runtime when a function is decorated with @jaxtyped(typechecker=...) or jaxtyping's import hook is active

Example (SquashedGaussianActor)

from jaxtyping import Float
from torch import Tensor

# Shapes are part of the signature — no comments needed
def forward(
    self, obs: Float[Tensor, "batch obs_dim"]
) -> tuple[Float[Tensor, "batch action_dim"], Float[Tensor, " batch"]]:
    ...

Where to apply

Priority files (core data flow, most shape-sensitive):

  1. roboro/core/types.py: Batch fields (obs, actions, rewards, etc.)
  2. roboro/actors/*.py: act() and forward() signatures
  3. roboro/critics/*.py: forward() signatures (discrete vs continuous shapes)
  4. roboro/updates/*.py: local variables in update() for clarity
  5. roboro/nn/blocks.py: MLPBlock.forward()

Lower priority (less shape-sensitive):

  1. roboro/data/replay_buffer.py
  2. roboro/encoders/*.py
  3. roboro/training/trainer.py

Install

pip install jaxtyping beartype

Runtime checking (optional, for tests)

To enable runtime shape validation during testing, decorate functions:

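from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

# Method of a critic module: "batch" must agree between the input and the returned Q-values.
@jaxtyped(typechecker=beartype)
def forward(self, obs: Float[Tensor, "batch obs_dim"]) -> Float[Tensor, "batch n_actions"]:
    ...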

Or enable checking globally for test runs only, e.g. by calling jaxtyping's install_import_hook from conftest.py before roboro is imported.

Notes

  • from __future__ import annotations was removed from the codebase (Python 3.12 doesn't need it), which unblocks jaxtyping's runtime inspection.
  • jaxtyping works purely as documentation even without beartype — shapes show up in IDE hover and function signatures.
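  • The " batch" syntax (leading space) denotes a single named dimension, exactly like "batch"; the space only stops linters such as flake8 from reading the string as an undefined forward-reference name. "batch dim" denotes two dimensions.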