@@ -54,22 +54,22 @@ static constexpr size_t CacheSize = RoundUp( ( BatchBuilderMaxBlockSize - 1 ) *
5454 RoundUp ( ( BatchBuilderMaxBlockSize ) * sizeof( ReferenceNode ), CacheAlignment ) +
5555 2 * RoundUp( BatchBuilderMaxBlockSize * sizeof ( uint32_t ), CacheAlignment ) +
5656 RoundUp( BatchBuilderMaxBlockSize * sizeof ( uint32_t ), CacheAlignment ) +
57- RoundUp( BatchBuilderMaxBlockSize * sizeof ( int2 ), CacheAlignment );
57+ RoundUp( BatchBuilderMaxBlockSize * sizeof ( int3 ), CacheAlignment );
5858
5959HIPRT_DEVICE size_t getStorageBufferSize ( const hiprtGeometryBuildInput& buildInput )
6060{
61- const size_t primCount = getPrimCount ( buildInput );
62- const size_t nodeSize = getNodeSize ( buildInput );
63- const size_t nodeCount = divideRoundUp ( 2 * primCount, 3 );
64- return getGeometryStorageBufferSize ( primCount, nodeCount, nodeSize );
61+ const size_t primCount = getPrimCount ( buildInput );
62+ const size_t primNodeSize = getPrimNodeSize ( buildInput );
63+ const size_t boxNodeCount = divideRoundUp ( 2 * primCount, 3 );
64+ return getGeometryStorageBufferSize ( primCount, boxNodeCount, primNodeSize );
6565}
6666
6767HIPRT_DEVICE size_t getStorageBufferSize ( const hiprtSceneBuildInput& buildInput )
6868{
69- const size_t frameCount = buildInput.frameCount ;
70- const size_t primCount = buildInput.instanceCount ;
71- const size_t nodeCount = divideRoundUp ( 2 * primCount, 3 );
72- return getSceneStorageBufferSize ( primCount, nodeCount , frameCount );
69+ const size_t frameCount = buildInput.frameCount ;
70+ const size_t primCount = buildInput.instanceCount ;
71+ const size_t boxNodeCount = divideRoundUp ( 2 * primCount, 3 );
72+ return getSceneStorageBufferSize ( primCount, primCount, boxNodeCount , frameCount );
7373}
7474
7575template <typename PrimitiveNode, typename PrimitiveContainer>
@@ -88,23 +88,12 @@ build( PrimitiveContainer& primitives, uint32_t geomType, MemoryArena& storageMe
8888 // STEP 0: Init data
8989 if constexpr ( is_same<Header, SceneHeader>::value )
9090 {
91- Instance* instances = storageMemoryArena.allocate <Instance>( primitives.getCount () );
92- uint32_t * masks = storageMemoryArena.allocate <uint32_t >( primitives.getCount () );
93- hiprtTransformHeader* transforms = storageMemoryArena.allocate <hiprtTransformHeader>( primitives.getCount () );
94- Frame* frames = storageMemoryArena.allocate <Frame>( primitives.getFrameCount () );
91+ Frame* frames = storageMemoryArena.allocate <Frame>( primitives.getFrameCount () );
92+ Instance* instances = storageMemoryArena.allocate <Instance>( primitives.getCount () );
9593
9694 primitives.setFrames ( frames );
9795 InitSceneData<>(
98- index,
99- storageMemoryArena.getStorageSize (),
100- primitives,
101- boxNodes,
102- primNodes,
103- instances,
104- masks,
105- transforms,
106- frames,
107- header );
96+ index, storageMemoryArena.getStorageSize (), primitives, boxNodes, primNodes, instances, frames, header );
10897 }
10998 else
11099 {
@@ -133,7 +122,7 @@ build( PrimitiveContainer& primitives, uint32_t geomType, MemoryArena& storageMe
133122 uint32_t * mortonCodeKeys = sharedMemoryArena.allocate <uint32_t >( blockDim.x );
134123 uint32_t * mortonCodeValues = sharedMemoryArena.allocate <uint32_t >( blockDim.x );
135124 uint32_t * updateCounters = sharedMemoryArena.allocate <uint32_t >( blockDim.x );
136- int2 * taskQueue = sharedMemoryArena.allocate <int2 >( blockDim.x );
125+ int3 * taskQueue = sharedMemoryArena.allocate <int3 >( blockDim.x );
137126
138127 // STEP 1: Calculate centroid bounding box by reduction
139128 updateCounters[index] = InvalidValue;
@@ -173,27 +162,21 @@ build( PrimitiveContainer& primitives, uint32_t geomType, MemoryArena& storageMe
173162 }
174163
175164 // STEP 4: Emit topology and refit nodes
176- EmitTopologyAndFitBounds (
177- index, mortonCodeKeys, mortonCodeValues, updateCounters, primitives, scratchNodes, references, primNodes );
165+ EmitTopologyAndFitBounds ( index, mortonCodeKeys, mortonCodeValues, updateCounters, primitives, scratchNodes, references );
178166 __syncthreads ();
179167
180168 // STEP 5: Collapse
181169 uint32_t rootAddr = updateCounters[primCount - 1 ];
182- if ( index == 0 ) taskQueue[0 ] = make_int2 ( rootAddr, InvalidValue );
170+ if ( index == 0 )
171+ taskQueue[index] = make_int3 ( encodeNodeIndex ( rootAddr, BoxType ), 0 , 0 );
172+ else
173+ taskQueue[index] = make_int3 ( InvalidValue, InvalidValue, InvalidValue );
183174 __syncthreads ();
184175
185- uint32_t taskCount = 1 ;
186- uint32_t taskOffset = 0 ;
187- while ( taskCount > 0 )
188- {
189- DeviceCollapse ( index, taskCount, taskOffset, header, scratchNodes, references, boxNodes, primNodes, taskQueue );
190- __syncthreads ();
191-
192- uint32_t nodeCount = header->m_boxNodeCount ;
193- taskOffset += taskCount;
194- taskCount = nodeCount - taskOffset;
195- __syncthreads ();
196- }
176+ uint32_t * taskCounter = &updateCounters[0 ];
177+ *taskCounter = 1 ;
178+ __syncthreads ();
179+ Collapse ( index, primCount, header, scratchNodes, references, boxNodes, primNodes, primitives, taskCounter, taskQueue );
197180}
198181
199182extern " C" __global__ void
0 commit comments