
v1.12.0


@github-actions released this 06 Mar 20:21

Warp v1.12.0

Warp v1.12 adds experimental hardware-accelerated texture sampling on CUDA GPUs, extends tile programming with element-wise arithmetic operators and differentiable FFT, and broadens JAX interoperability with jax.vmap support. This release also introduces subscript-style type hints for better IDE integration, new quaternion and approximate-math builtins, B-spline shape functions in warp.fem, and a collection of utility and diagnostics APIs.

New features

Hardware-accelerated textures

Experimental. This API may change without a formal deprecation cycle.

Warp v1.12 introduces wp.Texture1D, wp.Texture2D, and wp.Texture3D classes that leverage CUDA texture memory for hardware-accelerated interpolation directly inside Warp kernels. On GPU, texture reads are routed through dedicated texture units that perform filtered lookups in a single instruction, making them ideal for rendering, volume sampling, signed-distance-field queries, and simulation lookup tables. On CPU, a software fallback provides identical semantics so the same kernel code runs on both devices.

import warp as wp
import numpy as np

wp.init()

# 64x64 single-channel height map
data = np.random.rand(64, 64).astype(np.float32)

# Create a 2D texture with bilinear filtering
tex = wp.Texture2D(data, filter_mode=wp.Texture.FILTER_LINEAR)

@wp.kernel
def sample_texture(tex: wp.Texture2D, coords: wp.array[wp.vec2f], out: wp.array[float]):
    i = wp.tid()
    # Coordinates are in [0, 1]; bilinear interpolation is automatic
    out[i] = wp.texture_sample(tex, coords[i], dtype=float)

coords = wp.array(np.random.rand(1024, 2).astype(np.float32), dtype=wp.vec2f)
result = wp.zeros(1024, dtype=float)
wp.launch(sample_texture, dim=1024, inputs=[tex, coords, result])

print(f"Sampled {result.shape[0]} points, range: [{result.numpy().min():.4f}, {result.numpy().max():.4f}]")
# Example output: Sampled 1024 points, range: [0.0069, 0.9793]

Key capabilities:

  • 1D / 2D / 3D texture classes (wp.Texture1D, wp.Texture2D, wp.Texture3D) with matching wp.texture_sample() overloads that accept scalar, vec2f, or vec3f coordinates.
  • Filter modes: FILTER_POINT for nearest-neighbor sampling and FILTER_LINEAR for linear (1D), bilinear (2D), or trilinear (3D) interpolation.
  • Address modes: ADDRESS_WRAP, ADDRESS_CLAMP, ADDRESS_MIRROR, and ADDRESS_BORDER control how out-of-range texture coordinates are handled, configurable per axis.
  • Array interop: Texture objects provide copy_from_array() and copy_to_array() methods to transfer data between wp.array objects and texture memory. A cuda_surface property exposes the CUDA surface handle for advanced interop.
  • Broad dtype support: Textures accept integer and floating-point data types with 1, 2, or 4 channels. Integer types are automatically normalized to floating-point values on read.
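For readers unfamiliar with what FILTER_LINEAR computes, the pure-Python sketch below reproduces bilinear filtering with clamp and wrap addressing on the CPU. It assumes CUDA's normalized-coordinate convention, in which texel centers sit at (i + 0.5) / n; the function names and the "clamp"/"wrap" strings are hypothetical and are not part of the Warp API, and mirror/border modes are omitted. CUDA documents that hardware filtering weights are stored in low-precision fixed point, so GPU results can differ slightly from this full-precision reference.

```python
import math

def _address(i, n, mode):
    """Map an integer texel index into [0, n) using a clamp or wrap address mode."""
    if mode == "clamp":
        return min(max(i, 0), n - 1)
    if mode == "wrap":
        return i % n
    raise ValueError(f"unsupported address mode: {mode}")

def sample_bilinear(data, u, v, mode="clamp"):
    """Bilinearly sample a 2D grid (list of rows) at normalized coords (u, v).

    Assumes texel centers at (i + 0.5) / n; u indexes columns, v indexes rows.
    """
    h, w = len(data), len(data[0])
    # Convert to texel space, shifting by half a texel so centers land on integers
    x = u * w - 0.5
    y = v * h - 0.5
    x0, y0 = math.floor(x), math.floor(y)
    ax, ay = x - x0, y - y0  # fractional interpolation weights

    def texel(ix, iy):
        return data[_address(iy, h, mode)][_address(ix, w, mode)]

    # Interpolate along x on the two neighboring rows, then along y
    top = (1 - ax) * texel(x0, y0) + ax * texel(x0 + 1, y0)
    bottom = (1 - ax) * texel(x0, y0 + 1) + ax * texel(x0 + 1, y0 + 1)
    return (1 - ay) * top + ay * bottom

grid = [[0.0, 1.0], [2.0, 3.0]]
print(sample_bilinear(grid, 0.5, 0.5))  # 1.5 (average of all four texels)
```

Sampling exactly at a texel center returns that texel unchanged, while coordinates near the edge show the difference between address modes: with "wrap", u = 0.0 blends the first and last columns, whereas "clamp" repeats the edge texel.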

Subscript-style type hints

When annotating kernel parameters with call-syntax forms like wp.array(dtype=float), static type checkers such as Pyright and Pylance flag these as errors because the expressions look like constructor calls rather than type annotations. Warp v1.12 adds subscript-style alternatives that are recognized as valid generic aliases (#1216):

# Before (flagged as error by Pyright/Pylance):
@wp.kernel
def my_kernel(a: wp.array(dtype=float), b: wp.array2d(dtype=wp.vec3)):
    ...

# After (clean subscript syntax):
@wp.kernel
def my_kernel(a: wp.array[float], b: wp.array2d[wp.vec3]):
    ...

The subscript syntax is supported for all array dimensionalities (wp.array[dtype] through wp.array4d[dtype]) as well as wp.tile[dtype] for tile-typed arguments.

Warp's static type checking compatibility is being improved incrementally, and you may encounter other Pyright/Pylance diagnostics that are not yet resolved. If you run into type checking issues, please report them as sub-issues of #549.

Diagnostics utility

The new wp.print_diagnostics() function displays a comprehensive snapshot of the Warp build and runtime environment (software versions, CUDA information, build flags, and available devices) in a single call (#1221). Two companion helpers, wp.get_cuda_toolkit_version() and wp.get_cuda_driver_version(), return the CUDA toolkit and driver versions as integer tuples (#1172). Together these are useful for debugging environment issues, capturing context in CI logs, and providing system information when filing bug reports.

Quaternion and spatial helpers

Warp v1.12 adds quaternion and spatial transformation helpers: wp.quat_from_euler(), wp.quat_to_euler(), wp.transform_twist(), and wp.transform_wrench() (#1237). The Euler conversion functions accept axis indices (0 = X, 1 = Y, 2 = Z) so you can specify arbitrary rotation-order conventions such as ZYX or XYZ, making them suitable for robotics and animation pipelines:

import warp as wp

euler = wp.vec3(0.0, wp.PI / 4.0, 0.0)
q = wp.quat_from_euler(euler, 2, 1, 0)  # ZYX convention
print(q)  # [0.0, 0.3826834559440613, 0.0, 0.9238795042037964]
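The arithmetic behind that output can be reproduced in plain Python: a rotation of π/4 about Y is the quaternion (0, sin(π/8), 0, cos(π/8)) in Warp's (x, y, z, w) layout. The sketch below composes per-axis quaternions with a Hamilton product. It assumes that angles[n] pairs with the n-th listed axis and that rotations compose as q_i * q_j * q_k; that composition order is an assumption about the convention, not a description of Warp's implementation, and quat_from_euler_sketch is a hypothetical name.

```python
import math

def axis_quat(axis, angle):
    """Unit quaternion (x, y, z, w) rotating `angle` radians about axis 0=X, 1=Y, 2=Z."""
    q = [0.0, 0.0, 0.0, math.cos(angle / 2.0)]
    q[axis] = math.sin(angle / 2.0)
    return tuple(q)

def quat_mul(a, b):
    """Hamilton product of two (x, y, z, w) quaternions."""
    ax, ay, az, aw = a
    bx, by, bz, bw = b
    return (
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
        aw * bw - ax * bx - ay * by - az * bz,
    )

def quat_from_euler_sketch(angles, i, j, k):
    """Hypothetical: compose rotations about axes i, j, k as q_i * q_j * q_k."""
    qi = axis_quat(i, angles[0])
    qj = axis_quat(j, angles[1])
    qk = axis_quat(k, angles[2])
    return quat_mul(quat_mul(qi, qj), qk)

q = quat_from_euler_sketch((0.0, math.pi / 4.0, 0.0), 2, 1, 0)
print(tuple(round(c, 6) for c in q))  # matches the Warp output above to ~6 digits
```

For a single-axis rotation the composition order does not matter, which is why the result agrees with Warp regardless of the assumed convention; for mixed angles, compare against wp.quat_from_euler() before relying on the order.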

Approximate math intrinsics

wp.div_approx() and wp.inverse_approx() expose GPU hardware fast-math instructions (div.approx.f32 and rcp.approx.ftz.f64) for approximate floating-point division and reciprocal, offering higher throughput at reduced precision (#1199). Only floating-point types are supported. On CPU, both functions fall back to exact arithmetic so the same kernel code runs correctly on either device.

Marching cubes lookup tables

The internal marching cubes lookup tables are now exposed as public class attributes on wp.MarchingCubes: CUBE_CORNER_OFFSETS, EDGE_TO_CORNERS, CASE_TO_TRI_RANGE, and TRI_LOCAL_INDICES (#1151). These tables enable custom marching cubes implementations for advanced use cases such as sparse volume extraction or procedural mesh generation without having to duplicate the standard lookup data.
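As a rough illustration of how such tables are typically consumed, the first step of marching cubes classifies each cell by which of its eight corners lie below the iso-value, producing an 8-bit case index that selects a triangle configuration; vertices are then placed on crossed edges by linear interpolation. The pure-Python sketch below uses a hypothetical corner ordering and a "bit set when below iso" convention, both assumptions that may not match CUBE_CORNER_OFFSETS or Warp's case numbering, so substitute Warp's tables when building on it.

```python
# Hypothetical corner ordering (dx, dy, dz); Warp's CUBE_CORNER_OFFSETS may differ.
CORNERS = [
    (0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
    (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1),
]

def cube_case(field, x, y, z, iso):
    """8-bit case index for the cell whose minimum corner is (x, y, z).

    Bit c is set when corner c lies below the iso-value (assumed convention).
    """
    case = 0
    for bit, (dx, dy, dz) in enumerate(CORNERS):
        if field[x + dx][y + dy][z + dz] < iso:
            case |= 1 << bit
    return case

def edge_crossing(v0, v1, iso):
    """Fraction t in [0, 1] along an edge where linear interpolation reaches iso."""
    return (iso - v0) / (v1 - v0)

# A single 2x2x2 cell with one corner below the iso-value
field = [[[1.0 for _ in range(2)] for _ in range(2)] for _ in range(2)]
field[0][0][0] = 0.0
print(cube_case(field, 0, 0, 0, 0.5))  # 1: only corner 0 is inside
```

Cases 0 and 255 (all corners on one side) produce no triangles, which is why custom implementations often skip those cells before consulting the triangle tables.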

Graph coloring API

wp.utils.graph_coloring_assign(), wp.utils.graph_coloring_balance(), and wp.utils.graph_coloring_get_groups() are now part of the public API (#1145). These graph coloring utilities were originally introduced in warp.sim in v1.5.0 for use with VBDIntegrator and were removed along with the warp.sim module in v1.10.0. They are now reintroduced as standalone functions in wp.utils, independent of any physics module. They partition a graph into independent color groups, which is useful for parallel constraint solving, conflict-free mesh updates, and other tasks that require concurrent writes to non-adjacent elements.
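The idea behind these utilities can be illustrated with the classic greedy coloring heuristic: visit nodes in order and give each one the smallest color not already used by a neighbor. Nodes that share a color are never adjacent, so each color group can be processed in parallel without write conflicts. This pure-Python sketch is a generic illustration, not Warp's implementation, and does not reproduce the assignment or balancing strategies of the wp.utils functions.

```python
def greedy_coloring(num_nodes, edges):
    """Give each node the smallest color not used by an already-colored neighbor."""
    adj = [set() for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    colors = [-1] * num_nodes  # -1 marks "not yet colored"
    for node in range(num_nodes):
        used = {colors[n] for n in adj[node] if colors[n] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[node] = c
    return colors

def color_groups(colors):
    """Group node indices by color; each group is an independent set."""
    groups = {}
    for node, c in enumerate(colors):
        groups.setdefault(c, []).append(node)
    return [groups[c] for c in sorted(groups)]

# A 4-cycle needs two colors: opposite corners can be updated together
colors = greedy_coloring(4, [(0, 1), (1, 2), (2, 3), (3, 0)])
print(color_groups(colors))  # [[0, 2], [1, 3]]
```

A solver would then loop over the groups sequentially and launch one parallel pass per group. Greedy coloring can produce unevenly sized groups, which is the kind of imbalance a separate balancing step is meant to address.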

Tile programming enhancements

Tile arithmetic operators

Tiles now support native Python * and / operators for element-wise multiplication and division, including broadcast between tiles and scalar constants (#1006, #1009). The supported forms are tile * tile, tile * constant, constant * tile for multiplication, and tile / tile, tile / constant, constant / tile for division. All combinations are differentiable and work with scalar, vector, and matrix element types.

import warp as wp

TILE_SIZE = wp.constant(64)

@wp.kernel
def scale_and_normalize(
    a: wp.array[float],
    b: wp.array[float],
    out: wp.array[float],
):
    i = wp.tid()
    ta = wp.tile_load(a, shape=TILE_SIZE, offset=i * TILE_SIZE)
    tb = wp.tile_load(b, shape=TILE_SIZE, offset=i * TILE_SIZE)

    product = ta * tb          # element-wise multiply
    scaled = product * 0.5     # broadcast scalar multiply
    result = scaled / tb       # element-wise divide

    wp.tile_store(out, result, offset=i * TILE_SIZE)

N = 256
a = wp.ones(N, dtype=float)
b = wp.full(N, value=2.0, dtype=float)
out = wp.zeros(N, dtype=float)
wp.launch_tiled(scale_and_normalize, dim=[N // 64], inputs=[a, b, out], block_dim=64)

print(out.numpy()[:8])  # [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]

wp.tile_from_thread()

wp.tile_from_thread() broadcasts a scalar or vector value held by a single thread into a tile that every thread in the block can access (#1178). This is useful when one thread computes a value (e.g., a reduction result or a loop-invariant parameter) that the entire block needs to use in subsequent tile operations. The function accepts a thread_idx argument to specify which thread's value is broadcast, and supports both "shared" and "register" storage modes.

Differentiable FFT

wp.tile_fft() and wp.tile_ifft() now support reverse-mode automatic differentiation when recorded on a wp.Tape() (#1138). Warp automatically provides the correct gradient implementations for both transforms, so gradients propagate seamlessly through frequency-domain operations. This enables end-to-end gradient computation through pipelines that mix spatial and spectral steps, which is useful for differentiable signal processing, spectral methods, and PDE solvers.
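Reverse-mode differentiation through an FFT is cheap because the adjoint of the discrete Fourier transform is itself a scaled inverse transform. The NumPy check below verifies the identity <fft(x), g> = <x, n * ifft(g)> for NumPy's unnormalized forward FFT, and the matching identity for the inverse. It illustrates the underlying math only; the exact scaling applied by the tile_fft()/tile_ifft() gradients depends on Warp's normalization convention, which is not assumed here.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
g = rng.standard_normal(n) + 1j * rng.standard_normal(n)  # upstream gradient

# np.vdot conjugates its first argument, giving the complex inner product <a, b>.
# Adjoint of the unnormalized forward FFT is n * ifft:
lhs = np.vdot(np.fft.fft(x), g)
rhs = np.vdot(x, n * np.fft.ifft(g))

# Adjoint of the inverse FFT (which includes a 1/n factor) is fft / n:
lhs_inv = np.vdot(np.fft.ifft(x), g)
rhs_inv = np.vdot(x, np.fft.fft(g) / n)

print(np.allclose(lhs, rhs), np.allclose(lhs_inv, rhs_inv))  # True True
```

In other words, the backward pass of a forward FFT is an inverse FFT (and vice versa) up to a constant factor, so gradients cost about as much as the forward transform.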

MathDx GEMM toggle

Setting wp.config.enable_mathdx_gemm = False (or passing "enable_mathdx_gemm": False as a module option) disables cuBLASDx for wp.tile_matmul(), falling back to an optimized scalar GEMM implementation (#1228). This avoids the slow link-time optimization (LTO) step required by libmathdx during development iteration, while keeping libmathdx available for operations that have no scalar fallback, such as Cholesky factorization and FFT. The scalar fallback may be slower than cuBLASDx depending on tile sizes, data types, and block_dim, so this option is primarily intended for faster compile–edit–run cycles during development rather than production use.

Accelerated tile load/store

Shared-memory tile loads and stores via wp.tile_load() / wp.tile_store() have been accelerated for non-power-of-two tile sizes (#1239). The improvement is most pronounced when source arrays fit within the GPU L2 cache, reducing overhead for tile-based kernels that operate on irregularly shaped data blocks.

JAX integration

jax.vmap support

jax.vmap() can now be used with Warp kernels and callables exposed through jax_kernel() and jax_callable() (#859). This enables vectorized mapping over batched inputs, letting JAX automatically handle batching of Warp kernel invocations without manual loop constructs. The vmap_method parameter controls how batching is implemented: "broadcast_all" broadcasts all arrays to include the batch dimension, while "sequential" iterates over the batch dimension one element at a time.

import warp as wp
import jax
import jax.numpy as jp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def add_kernel(a: wp.array[float], b: wp.array[float], out: wp.array[float]):
    i = wp.tid()
    out[i] = a[i] + b[i]

# Wrap the Warp kernel for JAX
jax_add = jax_kernel(add_kernel, vmap_method="broadcast_all")

# Batched inputs: 3 arrays of 4 elements each
a = jp.arange(12, dtype=jp.float32).reshape((3, 4))
b = jp.ones((3, 4), dtype=jp.float32)

# Use jax.vmap to apply the kernel over the batch dimension
(result,) = jax.jit(jax.vmap(jax_add))(a, b)
print(result)
# [[ 1.  2.  3.  4.]
#  [ 5.  6.  7.  8.]
#  [ 9. 10. 11. 12.]]

has_side_effect flag

The new has_side_effect flag on jax_kernel() and jax_callable() ensures that JAX does not optimize away Warp FFI calls whose outputs are unused downstream (#1240). This is important for kernels that write to global state or perform I/O as side effects. Without the flag, JAX's dead-code elimination may silently skip their execution. Setting has_side_effect=True marks the call as effectful, forcing JAX to always execute it.

warp.fem enhancements

B-spline shape functions

Warp v1.12 adds B-spline basis functions to warp.fem with SquareBSplineShapeFunctions (2D) and CubeBSplineShapeFunctions (3D), accessible via ElementBasis.BSPLINE (#1208). B-spline bases provide higher continuity across element boundaries compared to standard Lagrange bases, which is beneficial for applications requiring smooth solutions such as thin-shell mechanics or isogeometric analysis. Degrees 1 through 3 are supported on Grid2D, Grid3D, and Nanogrid geometries.
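To make the continuity claim concrete, the pure-Python sketch below evaluates 1D B-spline basis functions on a uniform knot vector with the standard Cox-de Boor recursion. Uniform B-splines of degree p are C^(p-1) continuous across knots, whereas standard Lagrange elements are only C^0. This is a generic illustration of the basis, not the warp.fem implementation; how warp.fem maps these functions onto Grid2D, Grid3D, or Nanogrid cells (tensor products, boundary treatment) is not shown.

```python
def bspline_basis(i, p, t, knots):
    """Cox-de Boor recursion: value of the i-th B-spline of degree p at parameter t."""
    if p == 0:
        # Degree 0: indicator of the half-open knot span [knots[i], knots[i+1])
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left = 0.0
    denom = knots[i + p] - knots[i]
    if denom > 0:
        left = (t - knots[i]) / denom * bspline_basis(i, p - 1, t, knots)
    right = 0.0
    denom = knots[i + p + 1] - knots[i + 1]
    if denom > 0:
        right = (knots[i + p + 1] - t) / denom * bspline_basis(i + 1, p - 1, t, knots)
    return left + right

# Uniform knots 0..11 support 8 cubic basis functions (indices 0..7)
knots = [float(i) for i in range(12)]
values = [bspline_basis(i, 3, 5.0, knots) for i in range(8)]
print([round(v, 4) for v in values])  # nonzero entries 1/6, 2/3, 1/6 at i = 2, 3, 4
```

Away from the ends of the knot vector the basis functions sum to 1 (partition of unity), and each cubic function overlaps four knot spans, which is where the extra smoothness (and the wider stencil compared to Lagrange bases) comes from.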

Other improvements

  • The cells() operator now accepts traced fields, returning the underlying cell-level field for evaluation at cell-space samples (e.g., from lookup()).
  • PicQuadrature particles can now span multiple cells by passing a tuple of 2D arrays (cell_indices, coords, particle_fraction) to specify per-particle cell contributions.
  • Deprecation warnings have been added for the quadrature and domain arguments of interpolate(), and the space argument of make_space_restriction and make_space_partition (scheduled for removal in 1.14).

Compilation and tooling

  • NVRTC compilation no longer requires a CUDA driver, enabling wp.compile_aot_module() to produce PTX/CUBIN during Docker image builds where no GPU is available (#1085).
  • wp.compile_aot_module() now skips recompilation when the output binary already exists and wp.config.cache_kernels is enabled (#1246).
  • A --no-cuda flag has been added to build_lib.py for explicit CPU-only builds (#1223).
  • The new wp.config.cuda_arch_suffix setting appends architecture-specific suffixes to the --gpu-architecture flag passed to NVRTC (#1065).
  • Device.max_shared_memory_per_block exposes the maximum shared memory per block for CUDA devices (#1243).
  • wp.HashGrid now supports wp.float16 and wp.float64 coordinate types (#1007, #1168).
  • Any integer type (not just wp.int32) can now be used when indexing vectors and matrices in kernels (#1209).
  • Shared tile allocations on the stack are now enabled for all CPU architectures by defaulting wp.config.enable_tiles_in_stack_memory to True (#1032).

New examples

The new example_fft_poisson_navier_stokes_2d example demonstrates 2-D incompressible turbulence using a vorticity-streamfunction formulation. It uses tile-based fast Fourier transforms (FFT) to solve the Poisson equation on a periodic domain and advances vorticity transport with strong-stability-preserving Runge-Kutta (SSP-RK3) time integration, showcasing how Warp's tile FFT primitives can be applied to spectral PDE solvers.
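The Poisson step at the heart of that example can be sketched in NumPy. In a vorticity-streamfunction formulation (assumed here with the sign convention ω = -∇²ψ), each Fourier mode of ψ is obtained by dividing the corresponding mode of ω by |k|², with the k = 0 mode set to zero to fix the arbitrary mean. This uses NumPy's FFT rather than Warp's tile FFT, and solve_poisson_periodic is a hypothetical name, not code from the example.

```python
import numpy as np

def solve_poisson_periodic(omega, length=2 * np.pi):
    """Solve lap(psi) = -omega on a periodic n x n grid via FFT (zero-mean psi)."""
    n = omega.shape[0]
    k = np.fft.fftfreq(n, d=length / n) * 2 * np.pi  # angular wavenumbers
    kx, ky = np.meshgrid(k, k, indexing="ij")
    k2 = kx**2 + ky**2
    omega_hat = np.fft.fft2(omega)
    psi_hat = np.zeros_like(omega_hat)
    nonzero = k2 > 0  # skip the k = 0 mode, which only sets the mean
    # lap -> -|k|^2 in Fourier space, so -|k|^2 psi_hat = -omega_hat
    psi_hat[nonzero] = omega_hat[nonzero] / k2[nonzero]
    return np.real(np.fft.ifft2(psi_hat))

# Manufactured solution: psi = sin(x) sin(y) gives omega = 2 sin(x) sin(y)
n = 32
x = np.linspace(0.0, 2 * np.pi, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")
psi_true = np.sin(X) * np.sin(Y)
psi = solve_poisson_periodic(2.0 * psi_true)
print(np.max(np.abs(psi - psi_true)) < 1e-10)  # True
```

Because the solve is diagonal in Fourier space, its cost is dominated by the forward and inverse transforms, which is what makes tile-level FFT primitives a natural fit for this kind of spectral solver.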

Announcements

Python 3.9 deprecation

Python 3.9 reached end-of-life on October 31, 2025 and no longer receives security updates. Warp v1.12 deprecates Python 3.9 support. A DeprecationWarning is emitted both at runtime (when import warp runs under Python 3.9) and at build time (when compiling the native library with a Python 3.9 interpreter). Support for Python 3.9 will be removed entirely in Warp 1.13. To continue receiving Warp updates beyond v1.12.1, please migrate to Python 3.10 or newer.

Deprecations and removals

Implicit conversion of scalar values to composite types (vectors, matrices, etc.) when launching kernels or assigning to struct fields is now deprecated. Use explicit constructors such as wp.vec3(...) or wp.mat22(...) (#1022). Constructing matrices from row vectors via wp.matrix() has been removed after being deprecated in v1.9 (#1179); use wp.matrix_from_rows() or wp.matrix_from_cols() instead.

API cleanup finalization

The internal API cleanup begun in Warp v1.11, which added deprecation warnings and forwarding calls for internal symbols accessed through the public warp namespace, will be finalized in Warp v1.13. At that point, the deprecation messages and forwarding shims will be removed, and code that still accesses deprecated internal APIs will break. If you have been seeing deprecation warnings about internal symbol access, please update your code before v1.13 is released. If you need help migrating, feel free to ask on GitHub Discussions.

Acknowledgments

We also thank the following contributors from outside the core Warp development team:

  • @lenroe for reducing CPU kernel launch overhead by caching internal dispatch structures and optimizing type comparisons (#1160).
  • @StafaH for contributing to the hardware-accelerated texture feature through review and helpful contributions (#1169).
  • @Cucchi01 for identifying and contributing a fix for warp.fem temporaries not being released promptly due to reference cycles (#1075).
  • @nawedume for removing the deprecated distutils import from the OpenGL rendering example (#1205).
  • @clatim for fixing links and updating documentation examples (#1245).

For a complete list of changes, see the full changelog.