diff --git a/python/mujoco/experimental/mjwarp/sim.py b/python/mujoco/experimental/mjwarp/sim.py
new file mode 100644
index 0000000000..ffb08d08b1
--- /dev/null
+++ b/python/mujoco/experimental/mjwarp/sim.py
@@ -0,0 +1,165 @@
+# Copyright 2025 Kevin Zakka
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, cast
+
+import mujoco
+import mujoco_warp as mjwarp
+import numpy as np
+import warp as wp
+
+from mujoco.experimental.mjwarp.sim_data import WarpBridge
+
+# Default offscreen render resolution.
+HEIGHT = 240
+WIDTH = 320
+LS_PARALLEL = False
+
+# Type aliases for better IDE support while maintaining runtime compatibility.
+# At runtime, WarpBridge wraps the actual MJWarp objects.
+if TYPE_CHECKING:
+  ModelBridge = mjwarp.Model
+  DataBridge = mjwarp.Data
+else:
+  ModelBridge = WarpBridge
+  DataBridge = WarpBridge
+
+
+class Simulation:
+  """GPU-accelerated MuJoCo simulation powered by MJWarp."""
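+
+  # Usage sketch (illustrative sizes and path, not prescriptive; assumes a
+  # CUDA device with memory pools enabled so that graph capture kicks in):
+  #
+  #   model = mujoco.MjModel.from_xml_path("scene.xml")
+  #   sim = Simulation(
+  #       num_envs=1024, nconmax=8192, njmax=8192, model=model, device="cuda:0"
+  #   )
+  #   sim.step()              # advances all 1024 worlds by one timestep
+  #   qpos = sim.data.qpos    # zero-copy torch view of shape (1024, nq)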
+
+  def __init__(
+    self, num_envs: int, nconmax: int, njmax: int, model: mujoco.MjModel, device: str
+  ):
+    self.device = device
+    self.wp_device = wp.get_device(self.device)
+    self.num_envs = num_envs
+
+    self._mj_model = model
+    self._mj_data = mujoco.MjData(model)
+    mujoco.mj_forward(self._mj_model, self._mj_data)
+
+    with wp.ScopedDevice(self.wp_device):
+      self._wp_model = mjwarp.put_model(self._mj_model)
+      self._wp_model.opt.ls_parallel = LS_PARALLEL
+
+      self._wp_data = mjwarp.put_data(
+        self._mj_model,
+        self._mj_data,
+        nworld=self.num_envs,
+        nconmax=nconmax,
+        njmax=njmax,
+      )
+
+      self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs)
+      self._data_bridge = WarpBridge(self._wp_data)
+
+      # CUDA graphs require a CUDA device with memory pools enabled.
+      self.use_cuda_graph = self.wp_device.is_cuda and wp.is_mempool_enabled(
+        self.wp_device
+      )
+      self.create_graph()
+
+    self._camera = -1
+    self._renderer: mujoco.Renderer | None = None
+
+  def initialize_renderer(self) -> None:
+    if self._renderer is not None:
+      raise RuntimeError(
+        "Renderer is already initialized. Call 'close()' first to reinitialize."
+      )
+    self._renderer = mujoco.Renderer(model=self._mj_model, height=HEIGHT, width=WIDTH)
+
+  def create_graph(self) -> None:
+    """Capture step/forward as CUDA graphs when supported; otherwise leave them None."""
+    self.step_graph = None
+    self.forward_graph = None
+    if self.use_cuda_graph:
+      with wp.ScopedCapture() as capture:
+        mjwarp.step(self.wp_model, self.wp_data)
+      self.step_graph = capture.graph
+      with wp.ScopedCapture() as capture:
+        mjwarp.forward(self.wp_model, self.wp_data)
+      self.forward_graph = capture.graph
+
+  # Properties.
+
+  @property
+  def mj_model(self) -> mujoco.MjModel:
+    return self._mj_model
+
+  @property
+  def mj_data(self) -> mujoco.MjData:
+    return self._mj_data
+
+  @property
+  def wp_model(self) -> mjwarp.Model:
+    return self._wp_model
+
+  @property
+  def wp_data(self) -> mjwarp.Data:
+    return self._wp_data
+
+  @property
+  def data(self) -> "DataBridge":
+    return cast("DataBridge", self._data_bridge)
+
+  @property
+  def model(self) -> "ModelBridge":
+    return cast("ModelBridge", self._model_bridge)
+
+  @property
+  def renderer(self) -> mujoco.Renderer:
+    if self._renderer is None:
+      raise ValueError("Renderer not initialized. Call 'initialize_renderer()' first.")
+    return self._renderer
+
+  # Methods.
+
+  def reset(self) -> None:
+    # TODO(kevin): Should we be doing anything here?
+    pass
+
+  def forward(self) -> None:
+    with wp.ScopedDevice(self.wp_device):
+      if self.use_cuda_graph and self.forward_graph is not None:
+        wp.capture_launch(self.forward_graph)
+      else:
+        mjwarp.forward(self.wp_model, self.wp_data)
+
+  def step(self) -> None:
+    with wp.ScopedDevice(self.wp_device):
+      if self.use_cuda_graph and self.step_graph is not None:
+        wp.capture_launch(self.step_graph)
+      else:
+        mjwarp.step(self.wp_model, self.wp_data)
+
+  def update_render(self) -> None:
+    if self._renderer is None:
+      raise ValueError("Renderer not initialized. Call 'initialize_renderer()' first.")
+    # Copy GPU state back into the CPU MjData before updating the scene.
+    mjwarp.get_data_into(self._mj_data, self._mj_model, self._wp_data)
+    mujoco.mj_forward(self._mj_model, self._mj_data)
+    self._renderer.update_scene(data=self._mj_data, camera=self._camera)
+
+  def render(self) -> np.ndarray:
+    if self._renderer is None:
+      raise ValueError("Renderer not initialized. Call 'initialize_renderer()' first.")
+    return self._renderer.render()
+
+  def close(self) -> None:
+    if self._renderer is not None:
+      self._renderer.close()
+      self._renderer = None
diff --git a/python/mujoco/experimental/mjwarp/sim_data.py b/python/mujoco/experimental/mjwarp/sim_data.py
new file mode 100644
index 0000000000..9f446a0e3e
--- /dev/null
+++ b/python/mujoco/experimental/mjwarp/sim_data.py
@@ -0,0 +1,243 @@
+# Copyright 2025 Kevin Zakka
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bridge for PyTorch-Warp interoperability with zero-copy memory sharing.
+
+Provides automatic wrapping of Warp arrays as PyTorch-compatible objects
+while preserving shared memory and CUDA graph compatibility.
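+
+A usage sketch (illustrative; assumes `d` is an `mjwarp.Data` instance on a
+CUDA device):
+
+  bridge = WarpBridge(d)
+  qvel = bridge.qvel                        # TorchArray: zero-copy torch view
+  qvel[:] = 0.0                             # in-place write, seen by Warp kernels
+  norms = torch.linalg.norm(qvel, dim=-1)   # torch ops unwrap the proxy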
+""" + +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar + +import torch +import warp as wp + +T = TypeVar("T") + + +class TorchArray: + """Warp array that behaves like a torch.Tensor with shared memory.""" + + def __init__(self, wp_array: wp.array, nworld: int | None = None) -> None: + """Initialize the tensor proxy with a Warp array.""" + if ( + nworld is not None + and len(wp_array.shape) > 0 + and wp_array.strides[0] == 0 + and wp_array.shape[0] > nworld + ): + wp_array = wp_array[:nworld] # type: ignore + + self._wp_array = wp_array + self._tensor = wp.to_torch(wp_array) + self._is_cuda = not self._wp_array.device.is_cpu # type: ignore + self._torch_stream = self._setup_stream() + + def _setup_stream(self) -> Optional[torch.cuda.Stream]: + """Setup appropriate stream for the device.""" + if not self._is_cuda: + return None + + try: + warp_stream = wp.get_stream(self._wp_array.device) + return torch.cuda.ExternalStream(warp_stream.cuda_stream) + except Exception as e: + # Fallback to default stream if external stream creation fails. + print(f"Warning: Could not create external stream: {e}") + return torch.cuda.current_stream(self._tensor.device) + + @property + def wp_array(self) -> wp.array: + return self._wp_array + + def __repr__(self) -> str: + """Return string representation of the underlying tensor.""" + return repr(self._tensor) + + def __getitem__(self, idx: Any) -> Any: + """Get item(s) from the tensor using standard indexing.""" + return self._tensor[idx] + + def __setitem__(self, idx: Any, value: Any) -> None: + """Set item(s) in the tensor using standard indexing.""" + if self._is_cuda and self._torch_stream is not None: + with torch.cuda.stream(self._torch_stream): + self._tensor[idx] = value + else: + self._tensor[idx] = value + + def __getattr__(self, name: str) -> Any: + """Delegate attribute access to the underlying tensor.""" + return getattr(self._tensor, name) + + @classmethod + def __torch_function__( + cls, + func: Any, + types: Tuple[type, ...], + args: Tuple[Any, ...] = (), + kwargs: Optional[Dict[str, Any]] = None, + ) -> Any: + """Intercept torch.* function calls to unwrap TorchArray objects.""" + if kwargs is None: + kwargs = {} + + # Only intercept when at least one argument is our proxy. + if not any(issubclass(t, cls) for t in types): + return NotImplemented + + def _unwrap(x: Any) -> Any: + """Unwrap TorchArray objects to their underlying tensors.""" + return x._tensor if isinstance(x, cls) else x + + # Unwrap all TorchArray objects in args and kwargs. + unwrapped_args = tuple(_unwrap(arg) for arg in args) + unwrapped_kwargs = {k: _unwrap(v) for k, v in kwargs.items()} + + return func(*unwrapped_args, **unwrapped_kwargs) + + # Arithmetic operators. 
+
+  # Arithmetic operators.
+
+  def __add__(self, other: Any) -> Any:
+    return self._tensor + other
+
+  def __radd__(self, other: Any) -> Any:
+    return other + self._tensor
+
+  def __sub__(self, other: Any) -> Any:
+    return self._tensor - other
+
+  def __rsub__(self, other: Any) -> Any:
+    return other - self._tensor
+
+  def __mul__(self, other: Any) -> Any:
+    return self._tensor * other
+
+  def __rmul__(self, other: Any) -> Any:
+    return other * self._tensor
+
+  def __truediv__(self, other: Any) -> Any:
+    return self._tensor / other
+
+  def __rtruediv__(self, other: Any) -> Any:
+    return other / self._tensor
+
+  def __pow__(self, other: Any) -> Any:
+    return self._tensor**other
+
+  def __rpow__(self, other: Any) -> Any:
+    return other**self._tensor
+
+  def __neg__(self) -> Any:
+    return -self._tensor
+
+  def __pos__(self) -> Any:
+    return +self._tensor
+
+  def __abs__(self) -> Any:
+    return abs(self._tensor)
+
+  # Comparison operators.
+
+  def __eq__(self, other: Any) -> Any:
+    return self._tensor == other
+
+  def __ne__(self, other: Any) -> Any:
+    return self._tensor != other
+
+  def __lt__(self, other: Any) -> Any:
+    return self._tensor < other
+
+  def __le__(self, other: Any) -> Any:
+    return self._tensor <= other
+
+  def __gt__(self, other: Any) -> Any:
+    return self._tensor > other
+
+  def __ge__(self, other: Any) -> Any:
+    return self._tensor >= other
+
+
+def _contains_warp_arrays(obj: Any) -> bool:
+  """Check whether an object or its public attributes contain any Warp arrays."""
+  if isinstance(obj, wp.array):
+    return True
+
+  # Check if it's a struct-like object with attributes.
+  if hasattr(obj, "__dict__"):
+    return any(
+      isinstance(getattr(obj, attr), wp.array)
+      for attr in dir(obj)
+      if not attr.startswith("_")
+    )
+
+  return False
+
+
+class WarpBridge(Generic[T]):
+  """Wraps MJWarp objects to expose Warp arrays as PyTorch tensors.
+
+  Automatically converts Warp array attributes to TorchArray objects on
+  access, enabling direct PyTorch operations on simulation data. Recursively
+  wraps nested structures that contain Warp arrays.
+
+  IMPORTANT: This wrapper is read-only. To modify array data, use in-place
+  operations like `obj.field[:] = value`. Direct assignment like
+  `obj.field = new_array` raises an AttributeError to prevent accidental
+  memory address changes that break CUDA graphs.
+  """
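+
+  # The read-only contract in miniature (illustrative):
+  #
+  #   bridge = WarpBridge(wp_data)
+  #   bridge.qpos[:, :] = 0.0        # OK: writes through the TorchArray view
+  #   bridge.qpos = torch.zeros(1)   # AttributeError: would rebind the array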
" + f"Use in-place operations instead: obj.{name}[:] = value" + ) + + def __repr__(self) -> str: + """Return string representation of the wrapped struct.""" + return f"WarpBridge({repr(self._struct)})" + + @property + def struct(self) -> T: + """Access the underlying wrapped struct.""" + return self._struct