v1.12.0 #1273
shi-eric announced in Announcements
Warp v1.12.0
Warp v1.12 adds experimental hardware-accelerated texture sampling on CUDA GPUs, extends tile programming with element-wise arithmetic operators and differentiable FFT, and broadens JAX interoperability with jax.vmap support. This release also introduces subscript-style type hints for better IDE integration, new quaternion and approximate-math builtins, B-spline shape functions in warp.fem, and a collection of utility and diagnostics APIs.
New features
Hardware-accelerated textures
Warp v1.12 introduces wp.Texture1D, wp.Texture2D, and wp.Texture3D classes that leverage CUDA texture memory for hardware-accelerated interpolation directly inside Warp kernels. On GPU, texture reads are routed through dedicated texture units that perform filtered lookups in a single instruction, making them ideal for rendering, volume sampling, signed-distance-field queries, and simulation lookup tables. On CPU, a software fallback provides identical semantics so the same kernel code runs on both devices.
Key capabilities:
- wp.Texture1D, wp.Texture2D, and wp.Texture3D, with matching wp.texture_sample() overloads that accept scalar, vec2f, or vec3f coordinates.
- FILTER_POINT for nearest-neighbor sampling and FILTER_LINEAR for bilinear (2D) or trilinear (3D) interpolation.
- ADDRESS_WRAP, ADDRESS_CLAMP, ADDRESS_MIRROR, and ADDRESS_BORDER control how out-of-range texture coordinates are handled, configurable per axis.
- copy_from_array() and copy_to_array() methods to transfer data between wp.array objects and texture memory. A cuda_surface property exposes the CUDA surface handle for advanced interop.
Subscript-style type hints
When annotating kernel parameters with call-syntax forms like wp.array(dtype=float), static type checkers such as Pyright and Pylance flag these as errors because the expressions look like constructor calls rather than type annotations. Warp v1.12 adds subscript-style alternatives that are recognized as valid generic aliases (#1216).
The subscript syntax is supported for all array dimensionalities (wp.array[dtype] through wp.array4d[dtype]) as well as wp.tile[dtype] for tile-typed arguments.
Warp's static type checking compatibility is being improved incrementally, and you may encounter other Pyright/Pylance diagnostics that are not yet resolved. If you run into type checking issues, please report them as sub-issues of #549.
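The difference between the call-syntax and subscript forms described above can be sketched in plain Python. This is only an illustration of the general mechanism, not Warp's implementation: the Array class and the scale function are hypothetical stand-ins, and the sketch assumes Python 3.9 or newer for types.GenericAlias.

```python
import types


class Array:
    """Hypothetical stand-in for an array type (NOT Warp's wp.array)."""

    def __init__(self, dtype=None):
        self.dtype = dtype

    # Subscripting the class (Array[float]) returns a generic alias rather
    # than constructing an instance. Python calls this hook as a class method.
    def __class_getitem__(cls, dtype):
        return types.GenericAlias(cls, (dtype,))


# Call syntax builds an object, so a checker sees a constructor call.
call_form = Array(dtype=float)

# Subscript syntax builds a generic alias, which is valid annotation syntax.
subscript_form = Array[float]


def scale(values: Array[float], factor: float) -> None:
    """Hypothetical function annotated with the subscript form."""
```

The alias keeps the element type available for inspection through its __origin__ and __args__ attributes, so a framework can still recover the dtype from the annotation at runtime.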
Diagnostics utility
The new wp.print_diagnostics() function displays a comprehensive snapshot of the Warp build and runtime environment (software versions, CUDA information, build flags, and available devices) in a single call (#1221). Two companion helpers, wp.get_cuda_toolkit_version() and wp.get_cuda_driver_version(), return the CUDA toolkit and driver versions as integer tuples (#1172). Together these are useful for debugging environment issues, capturing context in CI logs, and providing system information when filing bug reports.
Quaternion and spatial helpers
Warp v1.12 adds quaternion and spatial transformation helpers: wp.quat_from_euler(), wp.quat_to_euler(), wp.transform_twist(), and wp.transform_wrench() (#1237). The Euler conversion functions accept axis indices (0 = X, 1 = Y, 2 = Z) so you can specify arbitrary rotation-order conventions such as ZYX or XYZ, making them suitable for robotics and animation pipelines.
Approximate math intrinsics
wp.div_approx() and wp.inverse_approx() expose GPU hardware fast-math instructions (div.approx.f32 and rcp.approx.ftz.f64) for approximate floating-point division and reciprocal, offering higher throughput at reduced precision (#1199). Only floating-point types are supported. On CPU, both functions fall back to exact arithmetic so the same kernel code runs correctly on either device.
Marching cubes lookup tables
The internal marching cubes lookup tables are now exposed as public class attributes on wp.MarchingCubes: CUBE_CORNER_OFFSETS, EDGE_TO_CORNERS, CASE_TO_TRI_RANGE, and TRI_LOCAL_INDICES (#1151). These tables enable custom marching cubes implementations for advanced use cases such as sparse volume extraction or procedural mesh generation without having to duplicate the standard lookup data.
Graph coloring API
wp.utils.graph_coloring_assign(), wp.utils.graph_coloring_balance(), and wp.utils.graph_coloring_get_groups() are now part of the public API (#1145). These graph coloring utilities were originally introduced in warp.sim in v1.5.0 for use with VBDIntegrator and were removed along with the warp.sim module in v1.10.0. They are now re-introduced as standalone functions in wp.utils, independent of any physics module. They partition a graph into independent color groups, which is useful for parallel constraint solving, conflict-free mesh updates, and other tasks that require concurrent writes to non-adjacent elements.
Tile programming enhancements
Tile arithmetic operators
Tiles now support native Python * and / operators for element-wise multiplication and division, including broadcast between tiles and scalar constants (#1006, #1009). The supported forms are tile * tile, tile * constant, and constant * tile for multiplication, and tile / tile, tile / constant, and constant / tile for division. All combinations are differentiable and work with scalar, vector, and matrix element types.
wp.tile_from_thread()
wp.tile_from_thread() broadcasts a scalar or vector value held by a single thread to a shared tile visible to all threads in the block (#1178). This is useful when one thread computes a value (e.g., a reduction result or a loop-invariant parameter) that the entire block needs to use in subsequent tile operations. The function accepts a thread_idx argument to specify which thread's value is broadcast, and supports both "shared" and "register" storage modes.
Differentiable FFT
wp.tile_fft() and wp.tile_ifft() now support reverse-mode automatic differentiation when recorded on a wp.Tape() (#1138). Warp automatically provides the correct gradient implementations for both transforms, so gradients propagate seamlessly through frequency-domain operations. This enables end-to-end gradient computation through pipelines that mix spatial and spectral steps, which is useful for differentiable signal processing, spectral methods, and PDE solvers.
MathDx GEMM toggle
Setting wp.config.enable_mathdx_gemm = False (or passing "enable_mathdx_gemm": False as a module option) disables cuBLASDx for wp.tile_matmul(), falling back to an optimized scalar GEMM implementation (#1228). This avoids the slow link-time optimization (LTO) step required by libmathdx during development iteration, while keeping libmathdx available for operations that have no scalar fallback, such as Cholesky factorization and FFT. The scalar fallback may be slower than cuBLASDx depending on tile sizes, data types, and block_dim, so this option is primarily intended for faster compile–edit–run cycles during development rather than production use.
Accelerated tile load/store
Shared-memory tile loads and stores via wp.tile_load() / wp.tile_store() have been accelerated for non-power-of-two tile sizes (#1239). The improvement is most pronounced when source arrays fit within the GPU L2 cache, reducing overhead for tile-based kernels that operate on irregularly shaped data blocks.
JAX integration
jax.vmap support
jax.vmap() can now be used with Warp kernels and callables exposed through jax_kernel() and jax_callable() (#859). This enables vectorized mapping over batched inputs, letting JAX automatically handle batching of Warp kernel invocations without manual loop constructs. The vmap_method parameter controls how batching is implemented: "broadcast_all" broadcasts all arrays to include the batch dimension, while "sequential" iterates over the batch dimension one element at a time.
has_side_effect flag
The new has_side_effect flag on jax_kernel() and jax_callable() ensures that JAX does not optimize away Warp FFI calls whose outputs are unused downstream (#1240). This is important for kernels that write to global state or perform I/O as side effects. Without the flag, JAX's dead-code elimination may silently skip their execution. Setting has_side_effect=True marks the call as effectful, forcing JAX to always execute it.
warp.fem enhancements
B-spline shape functions
Warp v1.12 adds B-spline basis functions to warp.fem with SquareBSplineShapeFunctions (2D) and CubeBSplineShapeFunctions (3D), accessible via ElementBasis.BSPLINE (#1208). B-spline bases provide higher continuity across element boundaries compared to standard Lagrange bases, which is beneficial for applications requiring smooth solutions such as thin-shell mechanics or isogeometric analysis. Degrees 1 through 3 are supported on Grid2D, Grid3D, and Nanogrid geometries.
Other improvements
- The cells() operator now accepts traced fields, returning the underlying cell-level field for evaluation at cell-space samples (e.g., from lookup()).
- PicQuadrature particles can now span multiple cells by passing a tuple of 2D arrays (cell_indices, coords, particle_fraction) to specify per-particle cell contributions.
- The quadrature and domain arguments of interpolate(), and the space argument of make_space_restriction and make_space_partition, are scheduled for removal in 1.14.
Compilation and tooling
- wp.compile_aot_module() to produce PTX/CUBIN during Docker image builds where no GPU is available (#1085, "Allow GPU kernels to be compiled without a GPU").
- wp.compile_aot_module() now skips recompilation when the output binary already exists and wp.config.cache_kernels is enabled (#1246).
- A --no-cuda flag has been added to build_lib.py for explicit CPU-only builds (#1223).
- The wp.config.cuda_arch_suffix setting appends architecture-specific suffixes to the --gpu-architecture flag passed to NVRTC (#1065).
- Device.max_shared_memory_per_block exposes the maximum shared memory per block for CUDA devices (#1243).
- wp.HashGrid now supports wp.float16 and wp.float64 coordinate types (#1007, #1168).
- wp.int32) can now be used when indexing vectors and matrices in kernels (#1209, "wp.argmin and wp.argmax return uint32 not int32").
- wp.config.enable_tiles_in_stack_memory to True (#1032, "Occasional failure in test_single_layer_nn_cpu").
New examples
The new example_fft_poisson_navier_stokes_2d example demonstrates 2-D incompressible turbulence using a vorticity-streamfunction formulation. It uses tile-based fast Fourier transforms (FFT) to solve the Poisson equation on a periodic domain and advances vorticity transport with strong-stability-preserving Runge-Kutta (SSP-RK3) time integration, showcasing how Warp's tile FFT primitives can be applied to spectral PDE solvers.
Announcements
Python 3.9 deprecation
Python 3.9 reached end-of-life on October 31, 2025 and no longer receives security updates. Warp v1.12 deprecates Python 3.9 support. A DeprecationWarning is emitted both at runtime (when import warp runs under Python 3.9) and at build time (when compiling the native library with a Python 3.9 interpreter). Support for Python 3.9 will be removed entirely in Warp 1.13. To continue receiving Warp updates beyond v1.12.1, please migrate to Python 3.10 or newer.
Deprecations and removals
Implicit conversion of scalar values to composite types (vectors, matrices, etc.) when launching kernels or assigning to struct fields is now deprecated. Use explicit constructors such as wp.vec3(...) or wp.mat22(...) (#1022). Constructing matrices from row vectors via wp.matrix() has been removed after being deprecated in v1.9 (#1179); use wp.matrix_from_rows() or wp.matrix_from_cols() instead.
API cleanup finalization
The internal API cleanup begun in Warp v1.11, which added deprecation warnings and forwarding calls for internal symbols accessed through the public warp namespace, will be finalized in Warp v1.13. At that point, the deprecation messages and forwarding shims will be removed, and code that still accesses deprecated internal APIs will break. If you have been seeing deprecation warnings about internal symbol access, please update your code before v1.13 is released. If you need help migrating, feel free to ask on GitHub Discussions.
Acknowledgments
We also thank the following contributors from outside the core Warp development team:
- warp.fem temporaries not being released promptly due to reference cycles (#1075, "Fix Explicit release of temporaries due to ref-cycle in fem.borrow_temporary").
- distutils import from the OpenGL rendering example (#1205, "Updated OpenGL rendering example to remove the distutils import").
- cg example in documentation (#1245).
For a complete list of changes, see the full changelog.
This discussion was created from the release v1.12.0.