|
3 | 3 | // |
4 | 4 | #include <fvdb/detail/GridBatchImpl.h> |
5 | 5 | #include <fvdb/detail/ops/ActiveGridGoords.h> |
6 | | -#include <fvdb/detail/utils/AccessorHelpers.cuh> |
7 | | -#include <fvdb/detail/utils/ForEachCPU.h> |
8 | | -#include <fvdb/detail/utils/cuda/ForEachCUDA.cuh> |
9 | | -#include <fvdb/detail/utils/cuda/ForEachPrivateUse1.cuh> |
10 | | - |
11 | | -#include <c10/cuda/CUDAException.h> |
| 6 | +#include <fvdb/detail/utils/SimpleOpHelper.h> |
12 | 7 |
|
13 | 8 | namespace fvdb { |
14 | 9 | namespace detail { |
15 | 10 | namespace ops { |
16 | 11 |
|
17 | | -/// @brief Per-voxel callback which computes the active grid coordinates for a batch of grids |
18 | | -template <template <typename T, int32_t D> typename TorchAccessor> |
19 | | -__hostdev__ inline void |
20 | | -activeGridCoordsVoxelCallback(int64_t batchIdx, |
21 | | - int64_t leafIdx, |
22 | | - int64_t voxelIdx, |
23 | | - GridBatchImpl::Accessor gridAccessor, |
24 | | - TorchAccessor<int32_t, 2> outGridCoords) { |
25 | | - const nanovdb::OnIndexGrid *grid = gridAccessor.grid(batchIdx); |
26 | | - const typename nanovdb::OnIndexGrid::LeafNodeType &leaf = |
27 | | - grid->tree().template getFirstNode<0>()[leafIdx]; |
28 | | - const int64_t baseOffset = gridAccessor.voxelOffset(batchIdx); |
29 | | - |
30 | | - const nanovdb::Coord &ijk = leaf.offsetToGlobalCoord(voxelIdx); |
31 | | - if (leaf.isActive(voxelIdx)) { |
32 | | - const int64_t idx = baseOffset + (int64_t)leaf.getValue(voxelIdx) - 1; |
33 | | - outGridCoords[idx][0] = ijk[0]; |
34 | | - outGridCoords[idx][1] = ijk[1]; |
35 | | - outGridCoords[idx][2] = ijk[2]; |
36 | | - } |
37 | | -} |
namespace {

/// @brief Per-voxel functor that writes the active grid coordinates for a batch of
/// grids directly into the output tensor.
///
/// Built on the SimpleOpHelper BaseProcessor (CRTP): the base class drives iteration
/// over the batch's active voxels on the device selected by @p DeviceTag and invokes
/// perActiveVoxel once per active voxel with its global ijk coordinate, its linear
/// index within the whole batch, and an accessor to the int32 output tensor
/// (presumably shaped [totalVoxels, 3] — the <int32_t, 3> base-class arguments and
/// the three writes below are consistent with that; confirm against BaseProcessor).
template <torch::DeviceType DeviceTag>
struct Processor : public BaseProcessor<DeviceTag, Processor<DeviceTag>, int32_t, 3> {
    /// @brief Store one active voxel's ijk coordinate at row @p feature_idx of the output.
    /// @param ijk The voxel's global integer coordinate
    /// @param feature_idx Linear index of this voxel across the whole batch (output row)
    /// @param out_accessor Accessor into the int32 output tensor
    __hostdev__ void
    perActiveVoxel(nanovdb::Coord const &ijk, int64_t const feature_idx, auto out_accessor) const {
        // nanovdb::Coord components are 32-bit; cast explicitly to match the
        // int32 output dtype rather than relying on implicit conversion.
        auto const i = static_cast<int32_t>(ijk[0]);
        auto const j = static_cast<int32_t>(ijk[1]);
        auto const k = static_cast<int32_t>(ijk[2]);

        // Active coords are saved directly to the output tensor.
        auto &&out = out_accessor[feature_idx];
        out[0] = i;
        out[1] = j;
        out[2] = k;
    }
};

} // End anonymous namespace
103 | 31 |
|
/// @brief Get the active (or enabled, for mutable grids) ijk coordinates in a batch
/// of grids.
/// @tparam DeviceTag Which device backend to run on (torch::kCUDA, torch::kCPU, or
///         torch::kPrivateUse1)
/// @param gridBatch The batch of grids to get the active coordinates for
/// @return A JaggedTensor of shape [B, -1, 3] of active/enabled IJK coordinates
template <torch::DeviceType DeviceTag>
JaggedTensor
dispatchActiveGridCoords(GridBatchImpl const &gridBatch) {
    // All per-device dispatch lives in BaseProcessor::execute; this function only
    // selects the backend via the template parameter.
    return Processor<DeviceTag>{}.execute(gridBatch);
}

// Explicit instantiations for every supported backend, replacing the former
// hand-written per-device specializations.
template JaggedTensor dispatchActiveGridCoords<torch::kCUDA>(GridBatchImpl const &);
template JaggedTensor dispatchActiveGridCoords<torch::kCPU>(GridBatchImpl const &);
template JaggedTensor dispatchActiveGridCoords<torch::kPrivateUse1>(GridBatchImpl const &);
115 | 41 |
|
116 | 42 | } // namespace ops |
117 | 43 | } // namespace detail |
|
0 commit comments