Skip to content

Commit 08aad1b

Browse files
authored
Active grid coords cleanup (#318)
Used the new SImpleOpHelper to simplify ActiveGridCoords. --------- Signed-off-by: Christopher Horvath <chorvath@nvidia.com>
1 parent 4a6b0e8 commit 08aad1b

File tree

2 files changed

+23
-98
lines changed

2 files changed

+23
-98
lines changed

src/fvdb/detail/ops/ActiveGridGoords.cu

Lines changed: 22 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -3,115 +3,41 @@
33
//
44
#include <fvdb/detail/GridBatchImpl.h>
55
#include <fvdb/detail/ops/ActiveGridGoords.h>
6-
#include <fvdb/detail/utils/AccessorHelpers.cuh>
7-
#include <fvdb/detail/utils/ForEachCPU.h>
8-
#include <fvdb/detail/utils/cuda/ForEachCUDA.cuh>
9-
#include <fvdb/detail/utils/cuda/ForEachPrivateUse1.cuh>
10-
11-
#include <c10/cuda/CUDAException.h>
6+
#include <fvdb/detail/utils/SimpleOpHelper.h>
127

138
namespace fvdb {
149
namespace detail {
1510
namespace ops {
1611

17-
/// @brief Per-voxel callback which computes the active grid coordinates for a batch of grids
18-
template <template <typename T, int32_t D> typename TorchAccessor>
19-
__hostdev__ inline void
20-
activeGridCoordsVoxelCallback(int64_t batchIdx,
21-
int64_t leafIdx,
22-
int64_t voxelIdx,
23-
GridBatchImpl::Accessor gridAccessor,
24-
TorchAccessor<int32_t, 2> outGridCoords) {
25-
const nanovdb::OnIndexGrid *grid = gridAccessor.grid(batchIdx);
26-
const typename nanovdb::OnIndexGrid::LeafNodeType &leaf =
27-
grid->tree().template getFirstNode<0>()[leafIdx];
28-
const int64_t baseOffset = gridAccessor.voxelOffset(batchIdx);
29-
30-
const nanovdb::Coord &ijk = leaf.offsetToGlobalCoord(voxelIdx);
31-
if (leaf.isActive(voxelIdx)) {
32-
const int64_t idx = baseOffset + (int64_t)leaf.getValue(voxelIdx) - 1;
33-
outGridCoords[idx][0] = ijk[0];
34-
outGridCoords[idx][1] = ijk[1];
35-
outGridCoords[idx][2] = ijk[2];
36-
}
37-
}
12+
namespace {
3813

39-
/// @brief Get the active grid coordinates for a batch of grids (including disabled coordinates in
40-
/// mutable grids)
41-
/// @param gridBatch The batch of grids
42-
/// @param outGridCoords Tensor which will contain the output grid coordinates
4314
template <torch::DeviceType DeviceTag>
44-
void
45-
GetActiveGridCoords(const GridBatchImpl &gridBatch, torch::Tensor &outGridCoords) {
46-
auto outCoordsAcc = tensorAccessor<DeviceTag, int32_t, 2>(outGridCoords);
47-
48-
if constexpr (DeviceTag == torch::kCUDA) {
49-
auto cb = [=] __device__(int64_t batchIdx,
50-
int64_t leafIdx,
51-
int64_t voxelIdx,
52-
int64_t,
53-
GridBatchImpl::Accessor gridAccessor) {
54-
activeGridCoordsVoxelCallback<TorchRAcc32>(
55-
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
56-
};
57-
forEachVoxelCUDA(1024, 1, gridBatch, cb);
58-
} else if constexpr (DeviceTag == torch::kPrivateUse1) {
59-
auto cb = [=] __device__(int64_t batchIdx,
60-
int64_t leafIdx,
61-
int64_t voxelIdx,
62-
int64_t,
63-
GridBatchImpl::Accessor gridAccessor) {
64-
activeGridCoordsVoxelCallback<TorchRAcc32>(
65-
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
66-
};
67-
forEachVoxelPrivateUse1(1, gridBatch, cb);
68-
} else {
69-
auto cb = [=](int64_t batchIdx,
70-
int64_t leafIdx,
71-
int64_t voxelIdx,
72-
int64_t,
73-
GridBatchImpl::Accessor gridAccessor) {
74-
activeGridCoordsVoxelCallback<TorchAcc>(
75-
batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc);
76-
};
77-
forEachVoxelCPU(1, gridBatch, cb);
15+
struct Processor : public BaseProcessor<DeviceTag, Processor<DeviceTag>, int32_t, 3> {
16+
// active coords get saved directly to the output tensor
17+
__hostdev__ void
18+
perActiveVoxel(nanovdb::Coord const &ijk, int64_t const feature_idx, auto out_accessor) const {
19+
auto const i = static_cast<int32_t>(ijk[0]);
20+
auto const j = static_cast<int32_t>(ijk[1]);
21+
auto const k = static_cast<int32_t>(ijk[2]);
22+
23+
auto &&out = out_accessor[feature_idx];
24+
out[0] = i;
25+
out[1] = j;
26+
out[2] = k;
7827
}
79-
}
28+
};
8029

81-
/// @brief Get the number of active (or enabled for mutable grids) ijk coordinates in a batch of
82-
/// grids
83-
/// @tparam DeviceTag Which device to run on
84-
/// @param gridBatch The batch of grids to get the active coordinates for
85-
/// @param ignoreDisabledVoxels If set to true, and the grid batch is mutable, also return
86-
/// coordinates that are disabled
87-
/// @return A JaggedTensor or shape [B, -1, 3] of active/enabled IJK coordinates
88-
template <torch::DeviceType DeviceTag>
89-
JaggedTensor
90-
ActiveGridCoords(const GridBatchImpl &gridBatch) {
91-
gridBatch.checkNonEmptyGrid();
92-
auto opts = torch::TensorOptions().dtype(torch::kInt32).device(gridBatch.device());
93-
torch::Tensor outGridCoords = torch::empty({gridBatch.totalVoxels(), 3}, opts);
94-
GetActiveGridCoords<DeviceTag>(gridBatch, outGridCoords);
95-
return gridBatch.jaggedTensor(outGridCoords);
96-
}
97-
98-
template <>
99-
JaggedTensor
100-
dispatchActiveGridCoords<torch::kCUDA>(const GridBatchImpl &gridBatch) {
101-
return ActiveGridCoords<torch::kCUDA>(gridBatch);
102-
}
30+
} // End anonymous namespace
10331

104-
template <>
32+
template <torch::DeviceType DeviceTag>
10533
JaggedTensor
106-
dispatchActiveGridCoords<torch::kCPU>(const GridBatchImpl &gridBatch) {
107-
return ActiveGridCoords<torch::kCPU>(gridBatch);
34+
dispatchActiveGridCoords(GridBatchImpl const &gridBatch) {
35+
return Processor<DeviceTag>{}.execute(gridBatch);
10836
}
10937

110-
template <>
111-
JaggedTensor
112-
dispatchActiveGridCoords<torch::kPrivateUse1>(const GridBatchImpl &gridBatch) {
113-
return ActiveGridCoords<torch::kPrivateUse1>(gridBatch);
114-
}
38+
template JaggedTensor dispatchActiveGridCoords<torch::kCUDA>(GridBatchImpl const &);
39+
template JaggedTensor dispatchActiveGridCoords<torch::kCPU>(GridBatchImpl const &);
40+
template JaggedTensor dispatchActiveGridCoords<torch::kPrivateUse1>(GridBatchImpl const &);
11541

11642
} // namespace ops
11743
} // namespace detail

src/fvdb/detail/ops/ActiveGridGoords.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ namespace fvdb {
1313
namespace detail {
1414
namespace ops {
1515

16-
template <torch::DeviceType>
17-
JaggedTensor dispatchActiveGridCoords(const GridBatchImpl &gridAccessor);
16+
template <torch::DeviceType> JaggedTensor dispatchActiveGridCoords(GridBatchImpl const &gridBatch);
1817

1918
} // namespace ops
2019
} // namespace detail

0 commit comments

Comments
 (0)