diff --git a/DispatchPreprocessingProposal.md b/DispatchPreprocessingProposal.md
new file mode 100644
index 00000000..445164ad
--- /dev/null
+++ b/DispatchPreprocessingProposal.md
@@ -0,0 +1,368 @@
+# Proposal: Dispatch Preprocessing Pass
+
+> **TL;DR** — Before the [functional GridBatch refactor](FunctionalRefactorProposal.md), create a new `fvdb/detail/dispatch/` module with device-unified forEach tools aligned with the new dispatch framework, then port ops one-by-one to use them and internalize their own device dispatch. After this pass, every op header exposes a type-erased function — no templates leak to callers. The first PR introduces the new tools and ports a minimum set of ops to prove the pattern. Subsequent PRs are mechanical "more of the same."
+
+---
+
+**Contents**
+
+- [Motivation](#motivation)
+- [Problems with the Current forEach Infrastructure](#problems-with-the-current-foreach-infrastructure)
+- [The New `fvdb/detail/dispatch/` Module](#the-new-fvdbdetaildispatch-module)
+- [How Ops Change](#how-ops-change)
+- [Deprecating SimpleOpHelper](#deprecating-simpleophelper)
+- [Op Inventory](#op-inventory)
+- [PR Strategy](#pr-strategy)
+- [Relationship to the Functional Refactor](#relationship-to-the-functional-refactor)
+
+---
+
+## Motivation
+
+Currently, most ops expose a `template <c10::DeviceType>` function in their header (see [How Ops Change](#how-ops-change) for a concrete before/after). Every caller must wrap the call in `FVDB_DISPATCH_KERNEL_DEVICE` to select the device at runtime. This has three costs:
+
+1. **Template leakage**: The `DeviceTag` template propagates from the op through the caller. Callers must be compiled in a CUDA-aware context even if they contain no device code.
+2. **Scattered dispatch**: `FVDB_DISPATCH_KERNEL_DEVICE` appears in [`GridBatchImpl.cu`](src/fvdb/detail/GridBatchImpl.cu) (~8 sites) and [`GridBatch.cpp`](src/fvdb/GridBatch.cpp) (~16 sites).
The [functional refactor](FunctionalRefactorProposal.md) deletes `GridBatch.cpp`, so these sites need a home. If they live inside the ops, the functional refactor becomes purely a Python-side change. +3. **Inconsistency**: The newer ops (morton/hilbert) already internalize dispatch with type-erased entry points. Most ops don't. + +But more fundamentally, the reason each op has to do per-device dispatch manually is because the underlying forEach tools are per-device. Fixing the tools fixes everything downstream. + +--- + +## Problems with the Current forEach Infrastructure + +The current iteration primitives ([`ForEachCPU.h`](src/fvdb/detail/utils/ForEachCPU.h), [`ForEachCUDA.cuh`](src/fvdb/detail/utils/cuda/ForEachCUDA.cuh), [`ForEachPrivateUse1.cuh`](src/fvdb/detail/utils/cuda/ForEachPrivateUse1.cuh)) have several design problems: + +### 1. Per-device triplication + +There are three separate functions for every iteration pattern: + +```cpp +forEachVoxelCPU(numChannels, batchHdl, func, args...); +forEachVoxelCUDA(numThreads, numChannels, batchHdl, func, args...); +forEachVoxelPrivateUse1(numChannels, batchHdl, func, args...); +``` + +Every op must choose the right one at every call site, leading to the pervasive `if constexpr (DeviceTag == torch::kCUDA) { ... } else if ...` pattern — even when the callback is `__hostdev__` and is identical across all devices. + +### 2. Incompatible signatures + +The three variants don't even have the same parameter lists. CUDA takes `numThreads` and `stream`; CPU doesn't. This means you can't write a single call that works on all devices. + +### 3. Callback signature overload + +Callbacks receive raw indices `(bidx, lidx, vidx, cidx, accessor, args...)` — a positional parade of integers with no type safety. Each op must manually decode these into meaningful voxel coordinates and feature indices. + +### 4. 
Channel parallelism entangled with iteration
+
+The `numChannels` / `channelsPerLeaf` parameter is baked into the iteration loop, mixing parallelism strategy with element traversal. This creates complex index arithmetic in every callback.
+
+### 5. Template parameter explosion
+
+Every forEach function is templated on the callback type and its argument pack (`Func`, `Args...`). This interacts with the per-device triplication to create a combinatorial explosion of template instantiations.
+
+[`SimpleOpHelper.h`](src/fvdb/detail/utils/SimpleOpHelper.h) was a step in the right direction — it introduced CRTP base classes (`BasePerActiveVoxelProcessor`, `BasePerElementProcessor`) that handle output allocation and hide the per-device dispatch behind `if constexpr`. But it still wraps the old forEach tools and inherits their problems. And CRTP is needless ceremony for what are conceptually just callbacks.
+
+---
+
+## The New `fvdb/detail/dispatch/` Module
+
+The new module provides **device-unified** iteration tools aligned with the new dispatch framework's design philosophy: write device variations only when you need to, and never more.
+
+### Core design principles
+
+1. **Single call site**: One function call iterates over the grid/tensor. The function handles CPU/CUDA/PrivateUse1 internally based on the input's device. No `if constexpr` in op code.
+2. **`__hostdev__` by default**: Most callbacks are `__hostdev__` and run identically on any device. But when an op needs device-specific behavior (vectorized intrinsics, warp-level primitives, etc.), it can still provide per-device specializations — the point is that the *framework* doesn't force you into per-device code when the algorithm doesn't need it.
+3. **No CRTP**: Callbacks are plain callables (lambdas, structs with `operator()`). No base classes to inherit from.
+4. **Separated concerns**: Iteration (what to visit) is separate from parallelism (how many threads/channels). Sensible defaults for thread counts; explicit override when needed.
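The single-call-site principle can be sketched in a few lines. Everything below is an illustrative, host-only toy — `Device`, `ToyGrid`, and `activeGridCoords` are stand-ins, not fvdb's real types or signatures:

```cpp
#include <array>
#include <cstdint>
#include <stdexcept>
#include <vector>

enum class Device { CPU, CUDA };

// Toy "grid": a flat list of active-voxel coordinates plus a device tag.
struct ToyGrid {
    Device device;
    std::vector<std::array<int, 3>> coords;
};

// Single entry point: the device branch lives here, never in op code.
template <typename Callback, typename... Args>
void forEachActiveVoxel(const ToyGrid &grid, Callback &&cb, Args &&...args) {
    switch (grid.device) {
    case Device::CPU:
        for (std::int64_t i = 0; i < static_cast<std::int64_t>(grid.coords.size()); ++i)
            cb(grid.coords[i], i, args...);
        break;
    case Device::CUDA:
        // The real tool would launch a kernel with the same callback here.
        throw std::runtime_error("CUDA path not available in this host-only sketch");
    }
}

// Op-side code: a plain callable -- no DeviceTag template, no CRTP base.
struct WriteCoords {
    void operator()(const std::array<int, 3> &ijk, std::int64_t idx,
                    std::vector<int> &out) const {
        out[3 * idx + 0] = ijk[0];
        out[3 * idx + 1] = ijk[1];
        out[3 * idx + 2] = ijk[2];
    }
};

std::vector<int> activeGridCoords(const ToyGrid &grid) {
    std::vector<int> out(3 * grid.coords.size());
    forEachActiveVoxel(grid, WriteCoords{}, out);
    return out;
}
```

Note that the op-side code contains no device branch at all; only the iteration tool knows about devices.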
+
+Not all ops will use the forEach tools. Ops with heavily device-specific internals (e.g., sparse convolution with CUDA-specific memory access patterns, or trilinear sampling with float4 vectorization) may continue to have per-device code paths internally. The forEach tools are for the common case — and that common case covers the majority of ops. For the rest, the type-erased wrapper pattern still applies: internalize the dispatch, expose a non-templated header.
+
+### `forEachActiveVoxel`
+
+Iterates over active voxels in a `GridBatchImpl`, calling a `__hostdev__` callback for each:
+
+```cpp
+// fvdb/detail/dispatch/ForEachActiveVoxel.cuh
+namespace fvdb::detail::dispatch {
+
+// Callback: void(nanovdb::Coord ijk, int64_t voxel_index, GridBatchImpl::Accessor acc, Args...)
+// where voxel_index is the linear feature index for this voxel
+template <typename Callback, typename... Args>
+void forEachActiveVoxel(const GridBatchImpl &grid, Callback &&callback, Args&&... args);
+
+}
+```
+
+An op using this:
+
+```cpp
+// No device template. No CRTP. Just a __hostdev__ callable.
+struct WriteCoords {
+    __hostdev__ void operator()(nanovdb::Coord ijk, int64_t idx,
+                                GridBatchImpl::Accessor, auto out) const {
+        out[idx][0] = ijk[0];
+        out[idx][1] = ijk[1];
+        out[idx][2] = ijk[2];
+    }
+};
+
+JaggedTensor activeGridCoords(const GridBatchImpl &grid) {
+    auto out = allocateOutput(grid, FixedShape{});
+    auto acc = makeAccessor(out);
+    dispatch::forEachActiveVoxel(grid, WriteCoords{}, acc);
+    return grid.jaggedTensor(out);
+}
+```
+
+Compare this to the current version in [`ActiveGridGoords.cu`](src/fvdb/detail/ops/ActiveGridGoords.cu), which requires a CRTP base class, explicit template instantiations for three device types, and a templated `dispatch*` entry point.
+
+### `forEachLeaf`
+
+For leaf-level iteration:
+
+```cpp
+// Callback: void(int64_t batch_idx, int64_t leaf_idx, GridBatchImpl::Accessor acc, Args...)
+template <typename Callback, typename... Args>
+void forEachLeaf(const GridBatchImpl &grid, Callback &&callback, Args&&... args);
+```
+
+### `forEachJaggedElement`
+
+For iteration over JaggedTensor elements. Note: `ScalarT` and `NDIMS` template parameters are retained because they determine the accessor type for type-safe tensor access — this is unavoidable without runtime type erasure on the tensor side.
+
+```cpp
+// Callback: void(int64_t batch_idx, int64_t element_idx, JaggedAccessor acc, Args...)
+template <typename ScalarT, int NDIMS, typename Callback, typename... Args>
+void forEachJaggedElement(const JaggedTensor &jt, Callback &&callback, Args&&... args);
+```
+
+### `forEachTensorElement`
+
+For plain tensor iteration:
+
+```cpp
+// Callback: void(int64_t element_idx, TensorAccessor acc, Args...)
+template <typename ScalarT, int NDIMS, typename Callback, typename... Args>
+void forEachTensorElement(const torch::Tensor &tensor, Callback &&callback, Args&&... args);
+```
+
+### With-channels variants
+
+For ops that need explicit channel parallelism (e.g., trilinear interpolation across feature channels), there would be `WithChannels` variants that expose `channel_idx` in the callback. But the default is no channel dimension — most ops don't need it.
+
+### Internal implementation
+
+Each dispatch tool has a `.cuh` header (since it instantiates CUDA kernels internally) that:
+1. Checks `grid.device()` / `tensor.device()`
+2. Dispatches to the appropriate existing low-level forEach (CPU loop, CUDA kernel launch, PrivateUse1 kernel launch)
+3. Provides sensible defaults for thread count and shared memory
+
+The low-level per-device forEach functions ([`ForEachCPU.h`](src/fvdb/detail/utils/ForEachCPU.h), [`ForEachCUDA.cuh`](src/fvdb/detail/utils/cuda/ForEachCUDA.cuh), [`ForEachPrivateUse1.cuh`](src/fvdb/detail/utils/cuda/ForEachPrivateUse1.cuh)) remain as the backend — they're correct and tested. The new dispatch module is a unifying layer on top, not a rewrite of the kernel launch mechanics.
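The layering just described — one front-end that checks the device, supplies defaults, and forwards to the existing per-device backends — might look like the following host-only sketch. Names, the default thread count, and the stubbed GPU paths are all illustrative, not fvdb's actual code:

```cpp
#include <cstdint>
#include <stdexcept>

enum class Device { CPU, CUDA, PrivateUse1 };

// --- existing low-level backends, kept as-is (signatures are illustrative;
// note they differ: only the CUDA variant takes a thread count) ---
template <typename F>
void forEachElementCPU(std::int64_t n, F f) {
    for (std::int64_t i = 0; i < n; ++i) f(i);
}
template <typename F>
void forEachElementCUDA(int /*numThreads*/, std::int64_t, F) {
    throw std::runtime_error("CUDA backend not available in this sketch");
}
template <typename F>
void forEachElementPrivateUse1(std::int64_t, F) {
    throw std::runtime_error("PrivateUse1 backend not available in this sketch");
}

// --- unifying layer: one call site, device checked at runtime,
// sensible default thread count supplied here and overridable ---
constexpr int kDefaultThreads = 256; // illustrative default

template <typename F>
void forEachElement(Device dev, std::int64_t n, F f,
                    int numThreads = kDefaultThreads) {
    switch (dev) {
    case Device::CPU:         forEachElementCPU(n, f); break;
    case Device::CUDA:        forEachElementCUDA(numThreads, n, f); break;
    case Device::PrivateUse1: forEachElementPrivateUse1(n, f); break;
    }
}
```

The incompatible backend signatures are absorbed here once, instead of at every op call site.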
+
+---
+
+## How Ops Change
+
+### Before (old-style header)
+
+```cpp
+// src/fvdb/detail/ops/CoordsInGrid.h (current)
+#include <fvdb/JaggedTensor.h>
+#include <fvdb/detail/GridBatchImpl.h> // full include
+
+template <c10::DeviceType DeviceTag> // template leaks to caller
+JaggedTensor dispatchCoordsInGrid(const GridBatchImpl &batchHdl, const JaggedTensor &coords);
+```
+
+### After (type-erased header)
+
+```cpp
+// src/fvdb/detail/ops/CoordsInGrid.h (target)
+#include <fvdb/JaggedTensor.h>
+
+namespace fvdb::detail { class GridBatchImpl; } // forward declaration only
+
+namespace fvdb::detail::ops {
+JaggedTensor coordsInGrid(const GridBatchImpl &grid, const JaggedTensor &coords);
+}
+```
+
+### The `.cu` file
+
+The internal logic is unchanged. What changes:
+1. The device dispatch moves from the caller into the op's `.cu` file
+2. The templated `dispatch*` function becomes `static` (file-internal) or is replaced entirely by using the new dispatch forEach tools
+3. A non-templated wrapper is the public entry point
+4. Explicit template instantiations at the bottom of the file are removed
+
+For ops that currently use `BasePerActiveVoxelProcessor`, the CRTP base class is replaced by a direct call to `dispatch::forEachActiveVoxel` with a plain callback.
+
+For complex ops (trilinear sampling, ray ops) that have custom iteration logic, the change is minimal: the existing templated function becomes `static`, and a ~5-line type-erased wrapper calls `FVDB_DISPATCH_KERNEL` to select the device and calls the static function.
+
+---
+
+## Deprecating SimpleOpHelper
+
+[`SimpleOpHelper.h`](src/fvdb/detail/utils/SimpleOpHelper.h) was a useful stepping stone. It introduced several good ideas:
+- Element type descriptors (`FixedElementType`, `DynamicElementType`, `ScalarElementType`)
+- Output tensor allocation helpers (`makeOutTensorFromGridBatch`, `makeOutTensorFromTensor`)
+- Accessor creation helpers (`makeAccessor`)
+
+These utilities are worth keeping (or evolving) as standalone helpers in `fvdb/detail/dispatch/`.
What gets deprecated is the CRTP base class pattern (`BasePerActiveVoxelProcessor`, `BasePerElementProcessor`), which is unnecessary once the dispatch forEach tools unify device dispatch. A plain `__hostdev__` callable + `dispatch::forEachActiveVoxel` is simpler and more composable than inheriting from a CRTP base. + +The output allocation and accessor helpers can live in `fvdb/detail/dispatch/OutputHelpers.h` or similar, available to all ops without requiring CRTP. + +--- + +## Op Inventory + +~50 ops need migration. They fall into categories by how much the new forEach tools can simplify them. + +### High simplification (currently use or can use per-voxel/per-element iteration) + +These benefit most from `forEachActiveVoxel` / `forEachTensorElement`. The CRTP base class and explicit template instantiations are eliminated entirely. + +| Op | Current pattern | Notes | +|:---|:----------------|:------| +| [`ActiveGridGoords`](src/fvdb/detail/ops/ActiveGridGoords.h) | `BasePerActiveVoxelProcessor` | Poster child for the new pattern | +| [`SerializeEncode`](src/fvdb/detail/ops/SerializeEncode.h) | `BasePerActiveVoxelProcessor` | | +| [`CoordsInGrid`](src/fvdb/detail/ops/CoordsInGrid.h) | Custom per-voxel | | +| [`IjkToIndex`](src/fvdb/detail/ops/IjkToIndex.h) | Custom per-voxel | | +| [`IjkToInvIndex`](src/fvdb/detail/ops/IjkToInvIndex.h) | Custom per-voxel | | +| [`PointsInGrid`](src/fvdb/detail/ops/PointsInGrid.h) | Custom per-element | | +| [`CubesInGrid`](src/fvdb/detail/ops/CubesInGrid.h) | Custom per-element | Also needs `Vec3dOrScalar` cleanup | +| [`JIdxForGrid`](src/fvdb/detail/ops/JIdxForGrid.h) | Custom per-voxel | | +| [`GridEdgeNetwork`](src/fvdb/detail/ops/GridEdgeNetwork.h) | Custom per-voxel | | +| [`Inject`](src/fvdb/detail/ops/Inject.h) | Custom per-voxel | | +| [`MortonHilbertFromIjk`](src/fvdb/detail/ops/MortonHilbertFromIjk.h) | `BasePerElementProcessor` | Already has type-erased entry point | + +### Medium simplification (custom loops, use 
`forEachJaggedElement`) + +These use `forEachJaggedElementChannel*` with custom callbacks. They benefit from `dispatch::forEachJaggedElement` (single call site) but keep their custom callback logic. + +| Op | Notes | +|:---|:------| +| [`SampleGridTrilinear`](src/fvdb/detail/ops/SampleGridTrilinear.h) | `AT_DISPATCH_V2` + float4 vectorization | +| [`SampleGridTrilinearWithGrad`](src/fvdb/detail/ops/SampleGridTrilinearWithGrad.h) | | +| [`SampleGridTrilinearWithGradBackward`](src/fvdb/detail/ops/SampleGridTrilinearWithGradBackward.h) | | +| [`SampleGridBezier`](src/fvdb/detail/ops/SampleGridBezier.h) | | +| [`SampleGridBezierWithGrad`](src/fvdb/detail/ops/SampleGridBezierWithGrad.h) | | +| [`SampleGridBezierWithGradBackward`](src/fvdb/detail/ops/SampleGridBezierWithGradBackward.h) | | +| [`SplatIntoGridTrilinear`](src/fvdb/detail/ops/SplatIntoGridTrilinear.h) | | +| [`SplatIntoGridBezier`](src/fvdb/detail/ops/SplatIntoGridBezier.h) | | +| [`TransformPointToGrid`](src/fvdb/detail/ops/TransformPointToGrid.cu) | | +| [`DownsampleGridAvgPool`](src/fvdb/detail/ops/DownsampleGridAvgPool.h) | | +| [`DownsampleGridMaxPool`](src/fvdb/detail/ops/DownsampleGridMaxPool.cu) | | +| [`UpsampleGridNearest`](src/fvdb/detail/ops/UpsampleGridNearest.h) | | +| [`VoxelNeighborhood`](src/fvdb/detail/ops/VoxelNeighborhood.h) | | +| [`ReadFromDense`](src/fvdb/detail/ops/ReadFromDense.h) | Also needs `Vec3iBatch` cleanup | +| [`ReadIntoDense`](src/fvdb/detail/ops/ReadIntoDense.h) | Also needs `Vec3iBatch` cleanup | +| [`NearestIjkForPoints`](src/fvdb/detail/ops/NearestIjkForPoints.h) | | +| [`CoarseIjkForFineGrid`](src/fvdb/detail/ops/CoarseIjkForFineGrid.h) | | +| [`ActiveVoxelsInBoundsMask`](src/fvdb/detail/ops/ActiveVoxelsInBoundsMask.h) | | + +### Minimal change (complex internals, just add type-erased wrapper) + +These have complex multi-pass logic. The main change is adding a type-erased wrapper; internal forEach migration is optional and can be done later. 
+ +| Op | Notes | +|:---|:------| +| [`VoxelsAlongRays`](src/fvdb/detail/ops/VoxelsAlongRays.h) | Multi-output | +| [`SegmentsAlongRays`](src/fvdb/detail/ops/SegmentsAlongRays.h) | | +| [`SampleRaysUniform`](src/fvdb/detail/ops/SampleRaysUniform.h) | Many parameters | +| [`RayImplicitIntersection`](src/fvdb/detail/ops/RayImplicitIntersection.h) | | +| [`MarchingCubes`](src/fvdb/detail/ops/MarchingCubes.h) | Multi-output | +| [`IntegrateTSDF`](src/fvdb/detail/ops/IntegrateTSDF.h) | Complex multi-output | +| [`VolumeRender`](src/fvdb/detail/ops/VolumeRender.h) | | +| [`ConvolutionKernelMap`](src/fvdb/detail/ops/convolution/pack_info/ConvolutionKernelMap.h) | Also needs `Vec3iOrScalar` cleanup | + +### Grid-building ops (return `GridHandle`) + +| Op | Notes | +|:---|:------| +| [`BuildGridFromIjk`](src/fvdb/detail/ops/BuildGridFromIjk.h) | | +| [`BuildGridFromPoints`](src/fvdb/detail/ops/BuildGridFromPoints.h) | | +| [`BuildGridFromMesh`](src/fvdb/detail/ops/BuildGridFromMesh.h) | | +| [`BuildGridFromNearestVoxelsToPoints`](src/fvdb/detail/ops/BuildGridFromNearestVoxelsToPoints.h) | | +| [`BuildCoarseGridFromFine`](src/fvdb/detail/ops/BuildCoarseGridFromFine.h) | | +| [`BuildFineGridFromCoarse`](src/fvdb/detail/ops/BuildFineGridFromCoarse.h) | | +| [`BuildDenseGrid`](src/fvdb/detail/ops/BuildDenseGrid.h) | | +| [`BuildDilatedGrid`](src/fvdb/detail/ops/BuildDilatedGrid.h) | | +| [`BuildMergedGrids`](src/fvdb/detail/ops/BuildMergedGrids.h) | | +| [`BuildPaddedGrid`](src/fvdb/detail/ops/BuildPaddedGrid.h) | | +| [`BuildPrunedGrid`](src/fvdb/detail/ops/BuildPrunedGrid.h) | | +| [`BuildGridForConv`](src/fvdb/detail/ops/BuildGridForConv.h) | | +| [`PopulateGridMetadata`](src/fvdb/detail/ops/PopulateGridMetadata.h) | | + +### JaggedTensor-only ops + +| Op | Notes | +|:---|:------| +| [`JaggedTensorIndex`](src/fvdb/detail/ops/JaggedTensorIndex.h) | | +| [`JCat0`](src/fvdb/detail/ops/JCat0.h) | | +| [`JIdxForJOffsets`](src/fvdb/detail/ops/JIdxForJOffsets.h) | | +| 
[`JOffsetsFromJIdx`](src/fvdb/detail/ops/JOffsetsFromJIdx.h) | | +| [`IjkForMesh`](src/fvdb/detail/ops/IjkForMesh.h) | | + +--- + +## PR Strategy + +The key insight for review is: **the new dispatch tools are the only novel code**. Every subsequent op migration is a mechanical application of a proven recipe. So the PR strategy front-loads the scrutiny. + +### PR 1: New dispatch tools + minimum viable op ports (small, reviewable) + +**What's in it:** +- New `src/fvdb/detail/dispatch/` directory with: + - `ForEachActiveVoxel.cuh` — device-unified active-voxel iteration over `GridBatchImpl` + - `ForEachLeaf.cuh` — device-unified leaf iteration over `GridBatchImpl` + - `ForEachJaggedElement.cuh` — device-unified element iteration over `JaggedTensor` + - `ForEachTensorElement.cuh` — device-unified element iteration over `torch::Tensor` + - `OutputHelpers.h` — output allocation and accessor utilities (extracted from `SimpleOpHelper.h`) +- Port of **3-4 ops** that cover the main patterns: + - `ActiveGridGoords` — proves `forEachActiveVoxel` (currently `BasePerActiveVoxelProcessor`) + - `MortonHilbertFromIjk` — proves `forEachTensorElement` (currently `BasePerElementProcessor`); already has type-erased entry point, just needs to drop `SimpleOpHelper` dependency + - `CoordsInGrid` — proves `forEachActiveVoxel` for an op that takes additional input (currently custom per-voxel) + - `SampleGridTrilinear` — proves the "just add a type-erased wrapper" pattern for a complex op +- Removal of `FVDB_DISPATCH_KERNEL_DEVICE` from the call sites of these 4 ops in `GridBatchImpl.cu` and `GridBatch.cpp` +- Old forEach tools remain untouched — no breakage + +**What reviewers scrutinize:** +- The design of the dispatch forEach tools (the only new code) +- Whether the ported ops maintain identical behavior (tests pass) +- Whether the pattern is clear enough to replicate mechanically + +**What reviewers don't need to worry about:** +- Changes to GridBatchImpl's interface +- Changes to the 
Python API
+- Changes to the autograd layer
+
+### PR 2-N: Bulk op ports (large, mechanical, low-scrutiny)
+
+Once the pattern from PR 1 is approved, the remaining ~45 ops are ported in batches. Each PR:
+- Picks a group of ops (5-10 at a time, grouped by pattern similarity)
+- Applies the same recipe
+- Removes `FVDB_DISPATCH_KERNEL_DEVICE` from corresponding call sites
+- Runs tests
+
+These PRs are "more of the same" — the reviewers have already approved the pattern.
+
+### Final PR: Cleanup
+
+- Mark `SimpleOpHelper.h`'s CRTP bases as deprecated (or delete if all consumers are ported)
+- Verify no op header contains `template <c10::DeviceType>`
+- Verify `FVDB_DISPATCH_KERNEL` / `FVDB_DISPATCH_KERNEL_DEVICE` no longer appears outside of `src/fvdb/detail/dispatch/` and individual op `.cu` files that need it for scalar-type dispatch
+- Old per-device forEach tools can be deprecated (they're still correct, just no longer the recommended path)
+
+---
+
+## Relationship to the Functional Refactor
+
+This preprocessing pass is designed to make the [functional GridBatch refactor](FunctionalRefactorProposal.md) dramatically simpler.
+
+**Before this pass**: The functional refactor needs ~25 thin C++ wrapper functions. About 15 of them exist solely to do `FVDB_DISPATCH_KERNEL_DEVICE` because the ops expose templated headers.
+
+**After this pass**: Ops expose type-erased functions. The functional refactor's thin C++ wrappers are only needed for the ~12 `autograd::Function::apply()` calls. For everything else, Python can call the type-erased op function through a direct pybind binding — no C++ wrapper needed at all.
+
+Additionally, `GridBatchImpl.cu`'s derived-grid methods (`coarsen`, `upsample`, etc.) stop using `FVDB_DISPATCH_KERNEL` — they just call the type-erased op functions. This simplifies `GridBatchImpl` even before the functional refactor touches it.
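The recipe this proposal applies everywhere — a templated implementation kept file-internal behind a type-erased entry point — can be compressed into a toy sketch. The enum, op name, and bodies below are stand-ins; in the real code the runtime branch is `FVDB_DISPATCH_KERNEL` over `c10::DeviceType`:

```cpp
#include <stdexcept>

enum class Device { CPU, CUDA }; // stand-in for c10::DeviceType

// Before: this template lived in the header, instantiated by every caller.
// After: it is static (file-internal) in the op's .cu file.
template <Device D>
static int opImpl(int x) {
    if constexpr (D == Device::CPU)
        return x + 1; // host path; a real op would run the CPU kernel here
    else
        throw std::runtime_error("no GPU available in this host-only sketch");
}

// Public entry point: type-erased, so no template reaches the header.
int op(Device dev, int x) {
    switch (dev) {
    case Device::CPU:  return opImpl<Device::CPU>(x);
    case Device::CUDA: return opImpl<Device::CUDA>(x);
    }
    throw std::runtime_error("unknown device");
}
```

Callers see only `op(dev, x)`; the template instantiations stay private to one translation unit.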
diff --git a/FunctionalRefactorProposal.md b/FunctionalRefactorProposal.md new file mode 100644 index 00000000..531d61e4 --- /dev/null +++ b/FunctionalRefactorProposal.md @@ -0,0 +1,367 @@ +# Proposal: Functional Refactor of the GridBatch Pipeline + +> **TL;DR** — Eliminate two entire C++ abstraction layers (~2,600 lines) by moving validation and ergonomic logic to Python — where it already largely exists — and replacing the wrapper code with ~25 thin, strict free functions that do only what C++ must do: autograd tape registration and CPU/CUDA template dispatch. **The public-facing Python API does not change. Existing Python tests should pass without modification.** + +--- + +**Contents** + +- [Motivation](#motivation) +- [What Changes](#what-changes) +- [What the Thin C++ Functions Look Like](#what-the-thin-c-functions-look-like) +- [What Python Gains](#what-python-gains) +- [What GridBatchImpl Becomes](#what-gridbatchimpl-becomes) +- [Immutability Cleanup in C++](#immutability-cleanup-in-c) +- [Functional Python API](#functional-python-api) +- [Design Principle: Strict in C++, Ergonomic in Python](#design-principle-strict-in-c-ergonomic-in-python) +- [Why This Is Safe](#why-this-is-safe) +- [Why This Can Be Largely Automated](#why-this-can-be-largely-automated) +- [Phasing](#phasing) +- [Metrics](#metrics) + +--- + +## Motivation + +### The current pipeline has four layers for every operation + +``` +Python (fvdb/grid_batch.py, fvdb/grid.py) + -> pybind binding (src/python/GridBatchBinding.cpp) + -> C++ wrapper class (src/fvdb/GridBatch.h, src/fvdb/GridBatch.cpp) + -> C++ core (src/fvdb/detail/GridBatchImpl.h, src/fvdb/detail/GridBatchImpl.cu) + -> dispatch / ops / autograd (actual kernels) +``` + +Most operations pass through all four layers with each one doing progressively less work. By the time you reach `GridBatchImpl`, the actual computation is typically a single dispatch call. 
The two middle layers — `GridBatch` and the binding — exist primarily to convert types and re-validate inputs that Python already validated.
+
+### The wrapper layer is already bypassed by the code that matters
+
+The autograd functions — the differentiable operations that are the heart of the library — **already work directly with `GridBatchImpl`**. Every single autograd function in [`src/fvdb/detail/autograd/`](src/fvdb/detail/autograd/) saves and restores `c10::intrusive_ptr<GridBatchImpl>` in the autograd context. The `GridBatch` wrapper class is invisible to them.
+
+The ops layer is the same story. Every dispatch function in [`src/fvdb/detail/ops/`](src/fvdb/detail/ops/) takes `const GridBatchImpl&`. No op or autograd function includes `GridBatch.h` or references the `GridBatch` class.
+
+### The C++ duck-typing system duplicates Python's job
+
+The `Vec3iOrScalar`, `Vec3dBatch`, `Vec3dBatchOrScalar` type system ([`TypesImpl.h`](src/fvdb/detail/TypesImpl.h), [`Types.h`](src/fvdb/Types.h), [`TypeCasters.h`](src/python/TypeCasters.h) — ~500 lines of template metaprogramming) exists to accept flexible inputs from Python (`1.0`, `[1,2,3]`, `torch.tensor([1,2,3])`) and convert them to strict NanoVDB types. But Python's [`fvdb/types.py`](fvdb/types.py) already provides `to_Vec3iBroadcastable`, `to_Vec3fBatchBroadcastable`, etc. — doing the same conversion with better error messages and simpler code. Every call site in [`grid_batch.py`](fvdb/grid_batch.py) already invokes the Python converter before crossing into C++, making the C++ conversion redundant.
+
+### Real cost
+
+The redundancy has concrete costs:
+
+- **Compile time**: [`GridBatch.h`](src/fvdb/GridBatch.h) pulls in [`GridBatchImpl.h`](src/fvdb/detail/GridBatchImpl.h), all of `Types.h`/`TypesImpl.h`, and transitively much of NanoVDB. Every translation unit that touches this pays the price.
+- **Comprehension time**: A new contributor tracing `sample_trilinear` must read through four files and two type-conversion systems to find the actual dispatch call. +- **Maintenance cost**: Adding a new parameter to an operation requires changes in `GridBatchImpl`, `GridBatch`, the binding, and the Python wrapper — four coordinated edits for one logical change. + +--- + +## What Changes + +### Files deleted + +| File | Lines | Role | +|:-----|------:|:-----| +| [`src/fvdb/GridBatch.h`](src/fvdb/GridBatch.h) | 912 | Wrapper class declaration | +| [`src/fvdb/GridBatch.cpp`](src/fvdb/GridBatch.cpp) | 1,232 | Wrapper class implementation | +| [`src/fvdb/detail/TypesImpl.h`](src/fvdb/detail/TypesImpl.h) | 281 | `Vec3*OrScalar` / `Vec3*Batch` templates | +| [`src/fvdb/Types.h`](src/fvdb/Types.h) | ~40 of 49 | Type aliases (keep `SpaceFillingCurveType` enum) | +| [`src/python/TypeCasters.h`](src/python/TypeCasters.h) | ~120 of 165 | pybind type casters for the above (keep `ScalarType` caster if needed) | +| **Total removed** | **~2,585** | | + +### Files added or modified + +| File | Change | +|:-----|:-------| +| **New:** `src/python/GridBatchOps.cpp` (~200 lines) | ~25 thin free functions bound via pybind11 ([see below](#what-the-thin-c-functions-look-like)) | +| [`src/python/GridBatchBinding.cpp`](src/python/GridBatchBinding.cpp) | Simplified — binds `GridBatchImpl` directly + the free functions | +| [`src/python/Bindings.cpp`](src/python/Bindings.cpp) | Remove `m.class_` registration | +| [`fvdb/grid_batch.py`](fvdb/grid_batch.py) | Calls thin C++ functions directly; gains validation that was in `GridBatch.cpp` | +| [`fvdb/_fvdb_cpp.pyi`](fvdb/_fvdb_cpp.pyi) | Updated stubs to reflect new binding surface | +| [`src/fvdb/detail/ops/CubesInGrid.h`](src/fvdb/detail/ops/CubesInGrid.h), [`.cu`](src/fvdb/detail/ops/CubesInGrid.cu) | Change `Vec3dOrScalar` params to strict `nanovdb::Vec3d` | +| 
[`src/fvdb/detail/ops/convolution/pack_info/ConvolutionKernelMap.h`](src/fvdb/detail/ops/convolution/pack_info/ConvolutionKernelMap.h), [`.cu`](src/fvdb/detail/ops/convolution/pack_info/ConvolutionKernelMap.cu) | Change `Vec3iOrScalar` params to strict `nanovdb::Coord` |
+| [`src/fvdb/detail/autograd/ReadFromDense.h`](src/fvdb/detail/autograd/ReadFromDense.h), [`.cpp`](src/fvdb/detail/autograd/ReadFromDense.cpp) | Change `Vec3iBatch` params to `std::vector<nanovdb::Coord>` |
+| [`src/fvdb/detail/autograd/ReadIntoDense.h`](src/fvdb/detail/autograd/ReadIntoDense.h), [`.cpp`](src/fvdb/detail/autograd/ReadIntoDense.cpp) | Change `Vec3iBatch` params to `std::vector<nanovdb::Coord>` |
+| [`src/fvdb/detail/viewer/Viewer.h`](src/fvdb/detail/viewer/Viewer.h) | Remove unused `#include <fvdb/GridBatch.h>` |
+
+### Files untouched
+
+- **All ops** ([`src/fvdb/detail/ops/`](src/fvdb/detail/ops/)) — already depend only on `GridBatchImpl` (the 2 ops listed above change only their `Vec3*OrScalar` parameter types, not their logic)
+- **All autograd functions** ([`src/fvdb/detail/autograd/`](src/fvdb/detail/autograd/)) — already depend only on `GridBatchImpl` (the 2 autograd files listed above change only their `Vec3iBatch` parameter types)
+- **`GridBatchImpl`** ([`.h`](src/fvdb/detail/GridBatchImpl.h), [`.cu`](src/fvdb/detail/GridBatchImpl.cu)) — the core implementation is unchanged
+- **The viewer** ([`src/fvdb/detail/viewer/`](src/fvdb/detail/viewer/)) — one dead include removed, no API change
+- **Python tests** — the public API is identical; tests should pass as-is
+
+---
+
+## What the Thin C++ Functions Look Like
+
+There are exactly two reasons code must stay in C++:
+
+1. **`torch::autograd::Function::apply()`** — registers nodes in the autograd tape. This is a C++ API.
+2. **`FVDB_DISPATCH_KERNEL_DEVICE`** — compile-time template instantiation for CPU/CUDA code paths.
+
+Each thin wrapper does one of these and nothing else. No validation, no type conversion, no default handling.
+
+### Autograd wrapper example
+
+```cpp
+// ~5 lines. No validation, no type conversion, no GridBatch class.
+std::vector<torch::Tensor>
+sample_trilinear_autograd(c10::intrusive_ptr<GridBatchImpl> grid,
+                          JaggedTensor points,
+                          torch::Tensor data,
+                          bool return_grad) {
+    return autograd::SampleGridTrilinear::apply(grid, points, data, return_grad);
+}
+```
+
+### Device dispatch wrapper example
+
+```cpp
+// ~6 lines. Just the template dispatch that Python can't do.
+std::vector<torch::Tensor>
+marching_cubes_dispatch(const GridBatchImpl &grid,
+                        torch::Tensor field,
+                        double level) {
+    return FVDB_DISPATCH_KERNEL_DEVICE(grid.device(), [&]() {
+        return ops::dispatchMarchingCubes<DeviceTag>(grid, field, level);
+    });
+}
+```
+
+These are bound as module-level functions:
+
+```cpp
+m.def("sample_trilinear", &sample_trilinear_autograd, ...);
+m.def("marching_cubes", &marching_cubes_dispatch, ...);
+```
+
+> **Note:** If the [dispatch preprocessing pass](DispatchPreprocessingProposal.md) is done first, the device dispatch wrappers (type 2 above) are no longer needed — the ops themselves handle dispatch internally. Only the ~12 autograd `::apply()` wrappers remain, and everything else can be bound directly from the type-erased op functions. This reduces the thin wrapper count from ~25 to ~12.
+
+---
+
+## What Python Gains
+
+Python takes ownership of everything that doesn't require C++ compilation.
+ +**Before** — Python delegates everything to the C++ wrapper: + +```python +# fvdb/grid_batch.py (current) +def sample_trilinear(self, points: JaggedTensor, voxel_data: JaggedTensor) -> JaggedTensor: + return JaggedTensor(impl=self._impl.sample_trilinear(points._impl, voxel_data._impl)) +``` + +**After** — Python owns validation, calls C++ only for the autograd apply: + +```python +# fvdb/grid_batch.py (proposed) +def sample_trilinear(self, points: JaggedTensor, voxel_data: JaggedTensor) -> JaggedTensor: + if points.ldim() != 1: + raise ValueError("Expected points to have 1 list dimension") + if voxel_data.ldim() != 1: + raise ValueError("Expected voxel_data to have 1 list dimension") + result = _fvdb_cpp.sample_trilinear( + self._grid_batch_impl, points._impl, voxel_data.jdata, False + ) + return points.jagged_like(result[0]) +``` + +For more complex operations like `max_pool`, Python absorbs the default-stride logic and coarse-grid creation that currently lives in [`GridBatch.cpp`](src/fvdb/GridBatch.cpp): + +```python +# fvdb/grid_batch.py (proposed) +def max_pool(self, pool_factor, data, stride=0, coarse_grid=None): + pool_factor = to_Vec3iBroadcastable(pool_factor, value_constraint=ValueConstraint.POSITIVE) + stride = to_Vec3iBroadcastable(stride, value_constraint=ValueConstraint.NON_NEGATIVE) + if (stride == 0).all(): + stride = pool_factor + if data.ldim() != 1: + raise ValueError("Expected data to have 1 list dimension") + if coarse_grid is None: + coarse_grid = self.coarsened_grid(stride) + result = _fvdb_cpp.max_pool_autograd( + self._grid_batch_impl, coarse_grid._grid_batch_impl, + pool_factor.tolist(), stride.tolist(), data.jdata + ) + return coarse_grid.jagged_like(result[0]), coarse_grid +``` + +This is the same logic that `GridBatch.cpp` currently contains — moved to where it's easier to read, test, and modify. + +--- + +## What GridBatchImpl Becomes + +[`GridBatchImpl`](src/fvdb/detail/GridBatchImpl.h) is already the right shape. 
It holds the core data (NanoVDB grid handle, metadata arrays, batch tensors), provides the `Accessor` for CUDA kernels, and supports indexing/slicing as views. + +The derived-grid methods currently on `GridBatchImpl` (`coarsen`, `upsample`, `dual`, `clip`, `dilate`, `merge`, `prune`, `convolutionOutput`) can optionally be extracted as free functions in a later phase, since they follow a pure functional pattern: read from `const GridBatchImpl&`, create a new `GridBatchImpl`, return it. But this is not required for the initial refactor — they already work correctly as methods and the ops/autograd layer doesn't care either way. + +`GridBatchImpl` retains its `CustomClassHolder` base and `TORCH_LIBRARY` registration, which is required by the autograd saved-variable mechanism. + +--- + +## Immutability Cleanup in C++ + +The library has formally embraced an immutable design on the Python side — grids are created, never mutated. The C++ layer hasn't fully caught up. [`GridBatchImpl`](src/fvdb/detail/GridBatchImpl.h) currently exposes several public mutation methods that are artifacts of an older design: + +| Method | Current callers | After refactor | +|:-------|:----------------|:---------------| +| `setGlobalVoxelSize` | [`GridBatch.cpp`](src/fvdb/GridBatch.cpp) only | **Delete** — `GridBatch.cpp` is removed, no callers remain | +| `setGlobalVoxelOrigin` | `GridBatch.cpp` only | **Delete** | +| `setGlobalPrimalTransform` | `GridBatch.cpp` only | **Delete** | +| `setGlobalDualTransform` | `GridBatch.cpp` only | **Delete** | +| `setGlobalVoxelSizeAndOrigin` | `GridBatch.cpp` only | **Delete** | +| `setFineTransformFromCoarseGrid` | Internal: called by `upsample()` on a freshly-created object | **Make private** | +| `setCoarseTransformFromFineGrid` | Internal: called by `coarsen()` on a freshly-created object | **Make private** | +| `setPrimalTransformFromDualGrid` | Internal: called by `dual()` on a freshly-created object | **Make private** | +| `setGrid` | Internal: 
construction only | **Make private** |

These are already inaccessible from the Python API — neither [`grid_batch.py`](fvdb/grid_batch.py) nor [`grid.py`](fvdb/grid.py) exposes any of them. The `setGlobal*` methods are bound via pybind in [`GridBatchBinding.cpp`](src/python/GridBatchBinding.cpp), but that binding is deleted along with `GridBatch.cpp`.

The three `set*TransformFrom*Grid` methods are used internally by `coarsen()`, `upsample()`, and `dual()` to adjust transforms on a freshly-constructed `GridBatchImpl` before returning it. Making them private preserves this internal use while ensuring the public C++ interface is fully const after construction. In Phase 4, if these derived-grid methods are extracted as free functions, the transform adjustment would be folded into the construction path itself, eliminating these setters entirely.

After this cleanup, `GridBatchImpl`'s public interface is **entirely read-only after construction**: create it, query it, build an accessor from it, index into it, serialize it. No public mutation. This matches the library's design contract and removes a class of potential misuse from the C++ API.

---

## Functional Python API

> *This section is de-emphasized for the initial pitch — it does not block or complicate the core refactor. It describes a natural follow-on that the refactor enables.*

The refactor naturally produces a set of Python functions that take a grid and data as arguments and return results — the same pattern as `torch.nn.functional` relative to `torch.nn`.
Once the class-based methods in `grid_batch.py` are calling thin C++ functions directly, we can surface those same calls as a public functional API: + +```python +# fvdb/functional.py (new module) + +def sample_trilinear(grid: GridBatch, points: JaggedTensor, voxel_data: JaggedTensor) -> JaggedTensor: + """Sample voxel data at world-space points using trilinear interpolation.""" + if points.ldim() != 1: + raise ValueError("Expected points to have 1 list dimension") + if voxel_data.ldim() != 1: + raise ValueError("Expected voxel_data to have 1 list dimension") + result = _fvdb_cpp.sample_trilinear(grid._grid_batch_impl, points._impl, voxel_data.jdata, False) + return points.jagged_like(result[0]) + +def coarsened_grid(grid: GridBatch, coarsening_factor: NumericMaxRank1) -> GridBatch: + """Return a coarsened copy of the grid.""" + coarsening_factor = to_Vec3iBroadcastable(coarsening_factor, value_constraint=ValueConstraint.POSITIVE) + return GridBatch(impl=grid._grid_batch_impl.coarsen(coarsening_factor.tolist())) +``` + +The class-based API then becomes a thin delegation layer: + +```python +# fvdb/grid_batch.py +class GridBatch: + def sample_trilinear(self, points, voxel_data): + return fvdb.functional.sample_trilinear(self, points, voxel_data) + + def coarsened_grid(self, coarsening_factor): + return fvdb.functional.coarsened_grid(self, coarsening_factor) +``` + +This gives users a choice of style — `grid.sample_trilinear(pts, data)` or `fvdb.functional.sample_trilinear(grid, pts, data)` — without maintaining two implementations. The class methods are one-liners that delegate to the functional versions. The `Grid` single-grid class follows the same pattern. + +This is a natural byproduct of the refactor, not an additional effort: the validation and dispatch logic that moves into Python *is* the functional implementation. The class methods just call it. 
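The shape of this arrangement can be sketched with self-contained toy stand-ins — `ToyGrid` and `scale_values` below are illustrative, not fvdb's actual classes or signatures:

```python
# Toy sketch of the delegation pattern: the free function owns validation and
# the work; the class method is a one-line delegation to it. No fvdb APIs here.

def scale_values(grid, factor):
    """Functional form: validate, then return a new grid (inputs untouched)."""
    if factor <= 0:
        raise ValueError("Expected a positive factor")
    return ToyGrid(v * factor for v in grid.values)

class ToyGrid:
    def __init__(self, values):
        self.values = list(values)  # copied at construction; never mutated after

    def scale_values(self, factor):
        return scale_values(self, factor)  # thin delegation, no second implementation

g = ToyGrid([1.0, 2.0, 3.0])
# both call styles hit the same implementation and agree
assert g.scale_values(2.0).values == scale_values(g, 2.0).values == [2.0, 4.0, 6.0]
```

The point the sketch makes is the one above: there is exactly one implementation, and the method-style call is free.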
+ +--- + +## Design Principle: Strict in C++, Ergonomic in Python + +This refactor enshrines a clear boundary: + +| Responsibility | Where it lives | +|:---|:---| +| Accept flexible user input (scalars, lists, tensors, mixed types) | **Python** | +| Validate shapes, dtypes, list dimensions, value ranges | **Python** | +| Handle default parameters and optional arguments | **Python** | +| Provide docstrings and type annotations | **Python** | +| Convert to strict, unambiguous types | **Python** | +| Register autograd tape nodes | **C++** (thin function) | +| Dispatch CPU/CUDA template instantiation | **C++** (thin function) | +| Core data storage, accessor, NanoVDB integration | **C++** (`GridBatchImpl`) | +| CUDA kernels and dispatch functions | **C++** (ops layer, unchanged) | + +C++ function signatures become strict and unambiguous — they take exactly the types the kernel needs. No duck-typing, no template metaprogramming for input flexibility, no multi-constructor overload resolution. If a kernel needs `nanovdb::Coord`, the C++ function takes `std::array` or `nanovdb::Coord`. Python has already done the conversion before the call crosses the boundary. + +--- + +## Why This Is Safe + +1. **The public Python API is unchanged.** [`grid_batch.py`](fvdb/grid_batch.py) and [`grid.py`](fvdb/grid.py) present the same classes, methods, properties, and type signatures. User code does not change. + +2. **The ops and autograd layers are untouched.** They already work with `GridBatchImpl` exclusively. The refactor only changes how they are *called*, not what they *do*. + +3. **The viewer is unaffected.** [`Viewer.h`](src/fvdb/detail/viewer/Viewer.h) has one unused include of `GridBatch.h` and no actual dependency on the wrapper class. The viewer team's code, data model, and API are not disrupted. + +4. **The testing infrastructure is already in place.** Python tests exercise the public API end-to-end. 
Since the API doesn't change, existing tests validate the refactored code paths. The C++ layers being removed are purely intermediary — they have no behavior that isn't already tested through the Python surface. + +5. **`GridBatchImpl` is stable.** It's not being restructured. Its public interface, accessor, memory management, and serialization format are all unchanged. + +--- + +## Why This Can Be Largely Automated + +The wrapper methods in [`GridBatch.cpp`](src/fvdb/GridBatch.cpp) follow a small number of mechanical patterns: + +| Pattern | Count | Transformation | +|:--------|------:|:---------------| +| **Pure forwarding**: `return mImpl->someMethod(args...)` | ~40 | Becomes direct Python call to `GridBatchImpl` binding | +| **Validation + autograd dispatch**: checks + `autograd::Fn::apply(mImpl, ...)` | ~12 | Validation moves to Python; `::apply()` becomes a thin bound function | +| **Validation + device dispatch**: checks + `FVDB_DISPATCH_KERNEL_DEVICE(...)` | ~15 | Same — validation to Python, dispatch stays in thin C++ function | +| **Type conversion + forwarding**: `Vec3iOrScalar` -> `nanovdb::Coord` -> forward | ~20 | C++ conversion deleted; Python's existing `to_Vec3iBroadcastable` is kept | + +Each pattern is identifiable by inspection and transformable by a consistent recipe. The existing test suite serves as the correctness oracle at every step. + +--- + +## Phasing + +The refactor can be done incrementally, with tests passing at every intermediate state. + +### Phase 1: Bind `GridBatchImpl` and thin functions alongside existing `GridBatch` + +Add pybind bindings for `GridBatchImpl` properties/methods and the ~25 thin wrapper functions. Both old and new bindings coexist. Wire up a few Python methods to use the new path. Run tests. 
+ +### Phase 2: Migrate `grid_batch.py` method by method + +For each method in `grid_batch.py`, switch from calling `self._impl.method(...)` (which goes through `GridBatch`) to calling `_fvdb_cpp.thin_function(self._grid_batch_impl, ...)` directly. Move validation from `GridBatch.cpp` to Python where it doesn't already exist. Run tests after each method. + +### Phase 3: Remove `GridBatch.h/cpp` and the `Vec3*` type system + +Once no Python code references the old `GridBatch` bindings, delete the files listed in [Files deleted](#files-deleted). Update [`Bindings.cpp`](src/python/Bindings.cpp) to remove the `GridBatch` class registration. Clean up the 4 ops/autograd files that use `Vec3*OrScalar`/`Vec3*Batch`. Remove the dead include from `Viewer.h`. Run tests. + +### Phase 4: Enforce immutability in `GridBatchImpl` + +Delete the `setGlobal*` public mutation methods from `GridBatchImpl` (no remaining callers after Phase 3). Make `setFineTransformFromCoarseGrid`, `setCoarseTransformFromFineGrid`, `setPrimalTransformFromDualGrid`, and `setGrid` private. This makes `GridBatchImpl`'s public interface fully read-only after construction, matching the library's immutability contract. See [Immutability Cleanup in C++](#immutability-cleanup-in-c) for details. + +### Phase 5: Surface `fvdb.functional` module + +Extract the validation and dispatch logic in `grid_batch.py` into a public `fvdb/functional.py` module of free functions. Rewrite the `GridBatch` and `Grid` class methods as one-liner delegations to the functional versions. This is a natural byproduct — the functional implementations already exist at this point as the method bodies; they just need to be lifted into a module. See [Functional Python API](#functional-python-api) for details. + +### Phase 6 (optional): Extract derived-grid ops from `GridBatchImpl` + +Move `coarsen`, `upsample`, `dual`, `clip`, `dilate`, `merge`, `prune`, `convolutionOutput` from `GridBatchImpl` methods to free functions. 
Fold the private transform-setter calls into the construction path, eliminating those methods entirely. This further slims the core class but is a lower-priority cleanup. + +--- + +## Metrics + +| Metric | Before | After | Delta | +|:-------|-------:|------:|------:| +| C++ lines in wrapper layer | ~2,585 | 0 | **-2,585** | +| C++ lines in thin functions | 0 | ~200 | +200 | +| Python lines in `grid_batch.py` | ~2,900 | ~3,100 | +200 | +| Total C++ files for GridBatch pipeline | 6 | 2 | **-4 files** | +| Layers in the call chain | 4 | 2 | **-2 layers** | +| C++ template metaprogramming for type flexibility | ~500 lines | 0 | **-500 lines** | +| Places to edit when adding a new operation | 4 files | 2 files | **-2 files** | + +--- + +## Conclusion + +The `GridBatch` wrapper and `Vec3*` type system were reasonable abstractions when the C++ layer was the primary API surface. Now that Python is the API surface — with its own type conversion, validation, docstrings, and static typing — these layers are pure overhead. Every operation is already functional in nature (const input -> new output), the autograd and ops layers already depend only on `GridBatchImpl`, and the Python layer already wraps everything. + +This refactor removes the empty middle, enshrines the boundary between Python ergonomics and C++ strictness, and makes the codebase smaller, faster to compile, and easier to understand — without changing a single user-facing behavior. It also brings the C++ layer into alignment with the library's immutability contract by eliminating public mutation methods that are already inaccessible from Python, and it naturally enables a `torch.nn.functional`-style functional Python API as a first-class alternative to the existing class-based interface.