Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 22 additions & 96 deletions src/fvdb/detail/ops/ActiveGridGoords.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,115 +3,41 @@
//
#include <fvdb/detail/GridBatchImpl.h>
#include <fvdb/detail/ops/ActiveGridGoords.h>
#include <fvdb/detail/utils/AccessorHelpers.cuh>
#include <fvdb/detail/utils/ForEachCPU.h>
#include <fvdb/detail/utils/cuda/ForEachCUDA.cuh>
#include <fvdb/detail/utils/cuda/ForEachPrivateUse1.cuh>

#include <c10/cuda/CUDAException.h>
#include <fvdb/detail/utils/SimpleOpHelper.h>

namespace fvdb {
namespace detail {
namespace ops {

/// @brief Per-voxel callback which computes the active grid coordinates for a batch of grids
/// @tparam TorchAccessor Accessor template used to index the output tensor (host or device flavor)
/// @param batchIdx Index of the grid within the batch
/// @param leafIdx Index of the leaf node within that grid
/// @param voxelIdx Offset of the voxel within the leaf
/// @param gridAccessor Accessor over the batch of NanoVDB OnIndex grids
/// @param outGridCoords [totalVoxels, 3] int32 tensor; the row for each active voxel receives its ijk
template <template <typename T, int32_t D> typename TorchAccessor>
__hostdev__ inline void
activeGridCoordsVoxelCallback(int64_t batchIdx,
int64_t leafIdx,
int64_t voxelIdx,
GridBatchImpl::Accessor gridAccessor,
TorchAccessor<int32_t, 2> outGridCoords) {
const nanovdb::OnIndexGrid *grid = gridAccessor.grid(batchIdx);
const typename nanovdb::OnIndexGrid::LeafNodeType &leaf =
grid->tree().template getFirstNode<0>()[leafIdx];
// First row of this grid's voxels within the flat, batch-wide output tensor.
const int64_t baseOffset = gridAccessor.voxelOffset(batchIdx);

const nanovdb::Coord &ijk = leaf.offsetToGlobalCoord(voxelIdx);
if (leaf.isActive(voxelIdx)) {
// NOTE(review): getValue() on an OnIndex leaf appears to yield the voxel's
// 1-based index among active values, hence the "- 1" — confirm against NanoVDB docs.
const int64_t idx = baseOffset + (int64_t)leaf.getValue(voxelIdx) - 1;
outGridCoords[idx][0] = ijk[0];
outGridCoords[idx][1] = ijk[1];
outGridCoords[idx][2] = ijk[2];
}
}
namespace {

/// @brief Get the active grid coordinates for a batch of grids (including disabled coordinates in
/// mutable grids)
/// @param gridBatch The batch of grids
/// @param outGridCoords Tensor which will contain the output grid coordinates
template <torch::DeviceType DeviceTag>
void
GetActiveGridCoords(const GridBatchImpl &gridBatch, torch::Tensor &outGridCoords) {
auto outCoordsAcc = tensorAccessor<DeviceTag, int32_t, 2>(outGridCoords);

if constexpr (DeviceTag == torch::kCUDA) {
auto cb = [=] __device__(int64_t batchIdx,
int64_t leafIdx,
int64_t voxelIdx,
int64_t,
GridBatchImpl::Accessor gridAccessor) {
activeGridCoordsVoxelCallback<TorchRAcc32>(
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
};
forEachVoxelCUDA(1024, 1, gridBatch, cb);
} else if constexpr (DeviceTag == torch::kPrivateUse1) {
auto cb = [=] __device__(int64_t batchIdx,
int64_t leafIdx,
int64_t voxelIdx,
int64_t,
GridBatchImpl::Accessor gridAccessor) {
activeGridCoordsVoxelCallback<TorchRAcc32>(
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
};
forEachVoxelPrivateUse1(1, gridBatch, cb);
} else {
auto cb = [=](int64_t batchIdx,
int64_t leafIdx,
int64_t voxelIdx,
int64_t,
GridBatchImpl::Accessor gridAccessor) {
activeGridCoordsVoxelCallback<TorchAcc>(
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
};
forEachVoxelCPU(1, gridBatch, cb);
struct Processor : public BaseProcessor<DeviceTag, Processor<DeviceTag>, int32_t, 3> {
// active coords get saved directly to the output tensor
__hostdev__ void
perActiveVoxel(nanovdb::Coord const &ijk, int64_t const feature_idx, auto out_accessor) const {
auto const i = static_cast<int32_t>(ijk[0]);
auto const j = static_cast<int32_t>(ijk[1]);
auto const k = static_cast<int32_t>(ijk[2]);

auto &&out = out_accessor[feature_idx];
out[0] = i;
out[1] = j;
out[2] = k;
}
}
};

/// @brief Get the active (or enabled for mutable grids) ijk coordinates in a batch of
/// grids
/// @tparam DeviceTag Which device to run on
/// @param gridBatch The batch of grids to get the active coordinates for
/// @return A JaggedTensor of shape [B, -1, 3] of active/enabled IJK coordinates
template <torch::DeviceType DeviceTag>
JaggedTensor
ActiveGridCoords(const GridBatchImpl &gridBatch) {
gridBatch.checkNonEmptyGrid();
auto opts = torch::TensorOptions().dtype(torch::kInt32).device(gridBatch.device());
// One int32 row per voxel across the whole batch; rows are filled by the per-voxel callback.
torch::Tensor outGridCoords = torch::empty({gridBatch.totalVoxels(), 3}, opts);
GetActiveGridCoords<DeviceTag>(gridBatch, outGridCoords);
// Re-split the flat [totalVoxels, 3] tensor into one entry per grid in the batch.
return gridBatch.jaggedTensor(outGridCoords);
}

/// @brief CUDA specialization of the dispatch entry point; forwards to the shared
/// ActiveGridCoords implementation.
template <>
JaggedTensor
dispatchActiveGridCoords<torch::kCUDA>(const GridBatchImpl &gridBatch) {
return ActiveGridCoords<torch::kCUDA>(gridBatch);
}
} // End anonymous namespace

template <>
template <torch::DeviceType DeviceTag>
JaggedTensor
dispatchActiveGridCoords<torch::kCPU>(const GridBatchImpl &gridBatch) {
return ActiveGridCoords<torch::kCPU>(gridBatch);
dispatchActiveGridCoords(GridBatchImpl const &gridBatch) {
return Processor<DeviceTag>{}.execute(gridBatch);
}

/// @brief PrivateUse1 (custom backend) specialization of the dispatch entry point;
/// forwards to the shared ActiveGridCoords implementation.
template <>
JaggedTensor
dispatchActiveGridCoords<torch::kPrivateUse1>(const GridBatchImpl &gridBatch) {
return ActiveGridCoords<torch::kPrivateUse1>(gridBatch);
}
// Explicit instantiations of the dispatch entry point for each supported backend.
template JaggedTensor dispatchActiveGridCoords<torch::kCUDA>(GridBatchImpl const &);
template JaggedTensor dispatchActiveGridCoords<torch::kCPU>(GridBatchImpl const &);
template JaggedTensor dispatchActiveGridCoords<torch::kPrivateUse1>(GridBatchImpl const &);

} // namespace ops
} // namespace detail
Expand Down
3 changes: 1 addition & 2 deletions src/fvdb/detail/ops/ActiveGridGoords.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ namespace fvdb {
namespace detail {
namespace ops {

/// @brief Return the active/enabled ijk coordinates of every grid in the batch as a
/// JaggedTensor of shape [B, -1, 3]; specialized per device backend.
template <torch::DeviceType>
JaggedTensor dispatchActiveGridCoords(const GridBatchImpl &gridAccessor);
template <torch::DeviceType> JaggedTensor dispatchActiveGridCoords(GridBatchImpl const &gridBatch);

} // namespace ops
} // namespace detail
Expand Down
Loading