Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions fvdb/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,40 @@ def has_zero_voxels(self) -> bool:
def ijk(self) -> torch.Tensor:
return self._impl.ijk.jdata

def serialize_encode(self, order_type: str = "z") -> torch.Tensor:
"""
Return the Morton codes for active voxels in this grid.

Morton codes provide a space-filling curve that maps 3D coordinates to 1D integers,
preserving spatial locality. This is useful for serialization, sorting, and
spatial data structures.

Returns:
torch.Tensor: A tensor of shape `[num_voxels, 1]` containing
the Morton codes for each active voxel.
"""
return self._impl.serialize_encode(order_type).jdata

def permute(self, order_type: str = "z") -> torch.Tensor:
"""
Get permutation indices to sort voxels by spatial order.

This method computes Morton codes for all active voxels and returns the indices
that would sort them according to the specified ordering. This is useful for
spatially coherent data access patterns and cache optimization.

Args:
order_type (str): The type of spatial ordering to use:
- "morton": Morton Z-order curve (default, ascending)
- "ascending": Sort Morton codes in ascending order (same as "morton")
- "descending": Sort Morton codes in descending order

Returns:
torch.Tensor: A tensor of shape `[num_voxels, 1]` containing
the permutation indices. Use these indices to reorder voxel data for spatial coherence.
"""
return self._impl.permute(order_type).jdata

@property
def num_bytes(self) -> int:
return self._impl.total_bytes
Expand Down
50 changes: 50 additions & 0 deletions fvdb/grid_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,56 @@ def has_zero_grids(self) -> bool:
def ijk(self) -> JaggedTensor:
return self._impl.ijk

def serialize_encode(self, order_type: str = "z") -> JaggedTensor:
"""
Return the space-filling curve codes for active voxels in this grid batch.

Space-filling curves provide a mapping from 3D coordinates to 1D integers,
preserving spatial locality. This is useful for serialization, sorting, and
spatial data structures.

Args:
order_type (str): The type of ordering to use:
- "z": Regular Z-order curve (xyz bit interleaving)
- "z-trans": Transposed Z-order curve (zyx bit interleaving)
- "hilbert": Regular Hilbert curve (xyz)
- "hilbert-trans": Transposed Hilbert curve (zyx)

Returns:
JaggedTensor: A JaggedTensor of shape `[num_grids, -1, 1]` containing
the space-filling curve codes for each active voxel in the batch.
"""
return self._impl.serialize_encode(order_type)

def permute(self, order_type: str = "z") -> JaggedTensor:
"""
Get permutation indices to sort voxels by spatial order.

This method computes space-filling curve codes for all active voxels and returns the indices
that would sort them according to the specified ordering. This is useful for
spatially coherent data access patterns and cache optimization.

Args:
order_type (str): The type of spatial ordering to use:
- "z": Regular Z-order curve (xyz bit interleaving, default)
- "z-trans": Transposed Z-order curve (zyx bit interleaving)
- "hilbert": Regular Hilbert curve (xyz)
- "hilbert-trans": Transposed Hilbert curve (zyx)

Returns:
JaggedTensor: A JaggedTensor of shape `[num_grids, -1, 1]` containing
the permutation indices. Use these indices to reorder voxel data for spatial coherence.

Example:
>>> z_indices = grid_batch.permute("z") # Regular xyz z-order
>>> z_trans_indices = grid_batch.permute("z-trans") # Transposed zyx z-order
>>> hilbert_indices = grid_batch.permute("hilbert") # Regular xyz Hilbert curve
>>> hilbert_trans_indices = grid_batch.permute("hilbert-trans") # Transposed zyx Hilbert curve
>>> # Use indices to reorder some voxel data
>>> reordered_data = voxel_data.jdata[z_indices.jdata.squeeze(-1)]
"""
return self._impl.permute(order_type)

@property
def jidx(self) -> torch.Tensor:
if self.has_zero_grids:
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ set(FVDB_CU_FILES
fvdb/detail/ops/PointsInGrid.cu
fvdb/detail/ops/IjkToIndex.cu
fvdb/detail/ops/ActiveGridGoords.cu
fvdb/detail/ops/SerializeEncode.cu
fvdb/detail/ops/IjkToInvIndex.cu
fvdb/detail/ops/ReadFromDense.cu
fvdb/detail/ops/NearestIjkForPoints.cu
Expand Down
62 changes: 62 additions & 0 deletions src/fvdb/GridBatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

// Ops headers
#include <fvdb/detail/ops/ActiveGridGoords.h>
#include <fvdb/detail/ops/SerializeEncode.h>
#include <tuple>
#include <vector>
#include <fvdb/detail/ops/CoordsInGrid.h>
#include <fvdb/detail/ops/CubesInGrid.h>
#include <fvdb/detail/ops/GridEdgeNetwork.h>
Expand Down Expand Up @@ -1081,6 +1084,65 @@ GridBatch::ijk() const {
});
}

JaggedTensor
GridBatch::serialize_encode(const std::string &order_type) const {
c10::DeviceGuard guard(device());
return FVDB_DISPATCH_KERNEL(this->device(), [&]() {
return fvdb::detail::ops::dispatchSerializeEncode<DeviceTag>(*mImpl, order_type);
});
}

JaggedTensor
GridBatch::permute(const std::string &order_type) const {
c10::DeviceGuard guard(device());

// Parse order type and determine space-filling curve type and sort order
std::string curve_order = "z"; // Default z-order (xyz)

if (order_type == "z") {
curve_order = "z";
} else if (order_type == "z-trans") {
curve_order = "z-trans";
} else if (order_type == "hilbert") {
curve_order = "hilbert";
} else if (order_type == "hilbert-trans") {
curve_order = "hilbert-trans";
} else {
TORCH_CHECK(false, "Invalid order_type: ", order_type,
". Valid options are 'z', 'z-trans', 'hilbert', or 'hilbert-trans'.");
}

// Get space-filling curve codes for sorting
JaggedTensor curve_codes = serialize_encode(curve_order);

// Create output tensor for permutation indices
auto opts = torch::TensorOptions().dtype(torch::kInt64).device(device());
std::vector<int64_t> shape = {mImpl->totalVoxels()};
torch::Tensor permutation_indices = torch::empty(shape, opts);

// Sort space-filling curve codes and get permutation indices for each grid
int64_t offset = 0;
for (int64_t grid_idx = 0; grid_idx < mImpl->batchSize(); ++grid_idx) {
int64_t num_voxels = mImpl->numVoxelsAt(grid_idx);
if (num_voxels == 0) continue;

// Extract space-filling curve codes for this grid
torch::Tensor grid_curve_codes = curve_codes.jdata().narrow(0, offset, num_voxels);

// Sort and get indices
auto sort_result = torch::sort(grid_curve_codes.squeeze(-1), /*dim=*/0);
torch::Tensor sorted_values = std::get<0>(sort_result);
torch::Tensor indices = std::get<1>(sort_result);

// Store indices with offset
permutation_indices.narrow(0, offset, num_voxels) = indices + offset;

offset += num_voxels;
}

return mImpl->jaggedTensor(permutation_indices.unsqueeze(-1));
}

std::vector<JaggedTensor>
GridBatch::viz_edge_network(bool returnVoxelCoordinates) const {
c10::DeviceGuard guard(device());
Expand Down
18 changes: 18 additions & 0 deletions src/fvdb/GridBatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,24 @@ struct GridBatch : torch::CustomClassHolder {
/// @return A JaggedTensor of voxel coordinates indexed by this grid batch (shape [B, -1, 3])
JaggedTensor ijk() const;

/// @brief Return the space-filling curve codes for active voxels in this grid batch with specific order type
/// @param order_type The type of ordering to use:
/// - "z": Regular Z-order curve (xyz bit interleaving)
/// - "z-trans": Transposed Z-order curve (zyx bit interleaving)
/// - "hilbert": Regular Hilbert curve (xyz)
/// - "hilbert-trans": Transposed Hilbert curve (zyx)
/// @return A JaggedTensor of space-filling curve codes for active voxels (shape [B, -1, 1])
JaggedTensor serialize_encode(const std::string &order_type) const;

/// @brief Get permutation indices to sort voxels by spatial order
/// @param order_type The type of spatial ordering to use:
/// - "z": Regular Z-order curve (xyz bit interleaving, default)
/// - "z-trans": Transposed Z-order curve (zyx bit interleaving)
/// - "hilbert": Regular Hilbert curve (xyz)
/// - "hilbert-trans": Transposed Hilbert curve (zyx)
/// @return A JaggedTensor of permutation indices (shape [B, -1])
JaggedTensor permute(const std::string &order_type) const;

/// @brief Find the intersection between a collection of rays and the zero level set of a scalar
/// field
/// at each voxel in the grid batch
Expand Down
Loading
Loading