@@ -35,61 +35,83 @@ struct BVHNode
35
35
end
36
36
37
37
abstract type LinearNode end
38
- struct LinearBVHLeaf <: LinearNode
38
+
39
"""
    LinearBVH(bounds, offset, n_primitives, split_axis, is_interior)

Node record of the flattened BVH. A single concrete type represents both leaf
and interior nodes; `is_interior` tells them apart.
"""
struct LinearBVH <: LinearNode
    # Bounding box of the node.
    bounds::Bounds3
    # Leaf: index of the node's first primitive.
    # Interior: index of the second child node.
    offset::UInt32
    # Number of primitives in a leaf (`LinearBVHInterior` stores 0 here).
    n_primitives::UInt32
    # Split axis of an interior node (`LinearBVHLeaf` stores 0 here).
    split_axis::UInt8
    # `true` for interior nodes, `false` for leaves.
    is_interior::Bool
end
48
- const LinearBVH = Union{LinearBVHLeaf,LinearBVHInterior}
49
46
50
- struct BVHAccel{P <: Primitive } <: AccelPrimitive
51
- primitives:: Vector{P}
52
- max_node_primitives:: UInt8
53
- nodes:: Vector{LinearBVH}
54
- nodes_to_visit:: Vector{Vector{Int32}}
47
"""
    LinearBVHLeaf(bounds, primitives_offset, n_primitives)

Create a leaf `LinearBVH`: `primitives_offset` is stored in the `offset` field,
`split_axis` is set to `0` and `is_interior` to `false`.
"""
LinearBVHLeaf(bounds::Bounds3, primitives_offset::Integer, n_primitives::Integer) =
    LinearBVH(bounds, primitives_offset, n_primitives, 0, false)
50
"""
    LinearBVHInterior(bounds, second_child_offset, split_axis)

Create an interior `LinearBVH`: `second_child_offset` is stored in the `offset`
field, `n_primitives` is set to `0` and `is_interior` to `true`.
"""
LinearBVHInterior(bounds::Bounds3, second_child_offset::Integer, split_axis::Integer) =
    LinearBVH(bounds, second_child_offset, 0, split_axis, true)
55
53
56
- function BVHAccel (
57
- primitives:: Vector{P} , max_node_primitives:: Integer = 1 ,
58
- ) where P<: Primitive
59
- max_node_primitives = min (255 , max_node_primitives)
60
- isempty (primitives) && return new {P} (primitives, max_node_primitives)
61
- nodes_to_visit = [zeros (Int32, 64 ) for _ in 1 : Threads. maxthreadid ()]
62
- primitives_info = [
63
- BVHPrimitiveInfo (i, world_bound (p))
64
- for (i, p) in enumerate (primitives)
65
- ]
54
"""
    primitives_to_bvh(primitives, max_node_primitives=1)

Build a BVH over `primitives` and flatten it into a vector of `LinearBVH` nodes.

`max_node_primitives` is capped at 255.

Returns the tuple `(ordered_primitives, max_node_primitives, nodes)`. For empty
input, `primitives` itself is returned together with an empty `LinearBVH[]`.
"""
function primitives_to_bvh(primitives, max_node_primitives=1)
    max_node_primitives = min(255, max_node_primitives)
    isempty(primitives) && return (primitives, max_node_primitives, LinearBVH[])

    primitives_info = [
        BVHPrimitiveInfo(i, world_bound(p)) for (i, p) in enumerate(primitives)
    ]

    # `_init` is declared with `total_nodes::Ref{Int64}`. `Ref(0)` yields a
    # `Ref{Int}`, which only matches where `Int === Int64`, so the element type
    # is spelled out explicitly.
    total_nodes = Ref{Int64}(0)
    ordered_primitives = similar(primitives, 0)
    root = _init(
        primitives, primitives_info, 1, length(primitives),
        total_nodes, ordered_primitives, max_node_primitives,
    )

    # `_unroll` bumps `offset` as it walks the tree; afterwards it must sit one
    # past the number of nodes counted by `_init`.
    offset = Ref{UInt32}(1)
    flattened = Vector{LinearBVH}(undef, total_nodes[])
    _unroll(flattened, root, offset)
    @real_assert total_nodes[] + 1 == offset[]
    return (ordered_primitives, max_node_primitives, flattened)
end
73
74
74
- offset = Ref {UInt32} (1 )
75
- flattened = Vector {LinearBVH} (undef, total_nodes[])
76
- _unroll (flattened, root, offset)
77
- @real_assert total_nodes[] + 1 == offset[]
75
"""
    BVHAccel{PVec,MatVec,NodeVec} <: AccelPrimitive

Acceleration structure holding the primitives, their materials, the per-node
primitive limit and the flattened `LinearBVH` nodes. The three containers are
type parameters, so any `AbstractVector` with a matching element type can be
stored.
"""
struct BVHAccel{
    PVec<:AbstractVector,
    MatVec<:AbstractVector{<:Material},
    NodeVec<:AbstractVector{LinearBVH}
} <: AccelPrimitive
    primitives::PVec
    materials::MatVec
    max_node_primitives::UInt8
    nodes::NodeVec
end
78
85
79
- new {P} (ordered_primitives, max_node_primitives, flattened, nodes_to_visit)
86
"""
    BVHAccel(primitives::AbstractVector, max_node_primitives::Integer=1)

Build a `BVHAccel` from `primitives`.

Each primitive is read through its `material` and `shape` fields, and each shape
through its `indices` field. Every consecutive group of three indices becomes
one `Triangle(shape, face_index, primitive_index)`. `primitive_index` is the
position of the originating primitive, which is also the position of its
material in the stored `materials` vector. The triangles are then handed to
`primitives_to_bvh`.
"""
function BVHAccel(
    primitives::AbstractVector{P}, max_node_primitives::Integer=1,
) where {P}
    materials = map(x -> x.material, primitives)

    triangles = Triangle[]
    for (mesh_idx, prim) in enumerate(primitives)
        mesh = prim.shape
        # Integer division: trailing indices that do not form a full triple
        # are ignored.
        for face_idx in 1:div(length(mesh.indices), 3)
            push!(triangles, Triangle(mesh, face_idx, mesh_idx))
        end
    end

    ordered_primitives, max_prim, nodes = primitives_to_bvh(
        triangles, max_node_primitives,
    )
    return BVHAccel(ordered_primitives, materials, UInt8(max_prim), nodes)
end
82
101
102
+
103
+
83
104
"""
    BucketInfo(count, bounds)

Mutable record pairing a `UInt32` counter with a `Bounds3` box.
"""
mutable struct BucketInfo
    count::UInt32
    bounds::Bounds3
end
87
108
88
109
function _init (
89
- primitives:: Vector{P} , primitives_info:: Vector{BVHPrimitiveInfo} ,
90
- from:: Integer , to:: Integer , total_nodes:: Ref{Int64} ,
91
- ordered_primitives:: Vector{P} , max_node_primitives:: Integer ,
92
- ) where P<: Primitive
110
+ primitives:: AbstractVector , primitives_info:: Vector{BVHPrimitiveInfo} ,
111
+ from:: Integer , to:: Integer , total_nodes:: Ref{Int64} ,
112
+ ordered_primitives:: AbstractVector , max_node_primitives:: Integer ,
113
+ )
114
+
93
115
total_nodes[] += 1
94
116
n_primitives = to - from + 1
95
117
# Compute bounds for all primitives in BVH node.
@@ -186,8 +208,9 @@ function _init(
186
208
end
187
209
188
210
function _unroll (
189
- linear_nodes:: Vector{LinearBVH} , node:: BVHNode , offset:: Ref{UInt32} ,
190
- )
211
+ linear_nodes:: Vector{LinearBVH} , node:: BVHNode , offset:: Ref{UInt32} ,
212
+ )
213
+
191
214
l_offset = offset[]
192
215
offset[] += 1
193
216
@@ -210,29 +233,51 @@ end
210
233
length (bvh. nodes) > 0 ? bvh. nodes[1 ]. bounds : Bounds3 ()
211
234
end
212
235
213
- function intersect! (bvh:: BVHAccel{P} , ray:: AbstractRay ):: Tuple{Bool,P,SurfaceInteraction} where {P}
236
"""
    @ntuple N value

Expand to a tuple literal with `N` slots, each containing the expression
`value`. `N` must be an integer literal because the tuple is unrolled at
macro-expansion time. `value` is spliced into every slot, so it is evaluated
once per slot, in the caller's scope.
"""
macro ntuple(N, value)
    slot = esc(value)
    return Expr(:tuple, (slot for _ in 1:N)...)
end
243
+
244
"""
    @setindex N tuple[index] = value

Rebind the variable `tuple` (an `N`-element tuple) to a copy in which slot
`index` holds `value`. `N` must be an integer literal; the copy is unrolled at
macro-expansion time into `N` branch-free `ifelse` selections.

`index` and `value` are each evaluated exactly once. If `index` is outside
`1:N`, no slot matches and the tuple is rebound to an identical copy; no error
is raised.

Throws `ArgumentError` at expansion time when the arguments do not have the
`tuple[index] = value` shape.
"""
macro setindex(N, setindex_expr)
    N isa Integer ||
        throw(ArgumentError("@setindex: N must be an integer literal, got $(repr(N))"))
    Meta.isexpr(setindex_expr, :(=)) ||
        throw(ArgumentError("@setindex expects `tuple[index] = value`"))
    target = setindex_expr.args[1]
    Meta.isexpr(target, :ref) ||
        throw(ArgumentError("@setindex expects `tuple[index] = value`"))

    tup = esc(target.args[1])
    idx = esc(target.args[2])
    val = esc(setindex_expr.args[2])

    # One `ifelse` per slot: keep the old element unless this is the target slot.
    # `index` and `value` below are hygienic locals bound once in the `let`, so
    # the caller's expressions are not re-evaluated for every slot.
    slots = Expr(:tuple)
    for i in 1:N
        push!(slots.args, :(ifelse($i != index, $(tup)[$i], value)))
    end

    return :($tup = let index = $idx, value = $val
        $slots
    end)
end
257
+
258
+ @inline function intersect! (bvh:: BVHAccel{P} , ray:: AbstractRay ) where {P}
214
259
hit = false
215
260
interaction = SurfaceInteraction ()
216
- isempty (bvh. nodes) && return hit, nothing , interaction
217
261
218
262
ray = check_direction (ray)
219
263
inv_dir = 1f0 ./ ray. d
220
264
dir_is_neg = is_dir_negative (ray. d)
221
265
222
- to_visit_offset:: Int32 , current_node_i:: Int32 = 1 , 1
223
- @inbounds nodes_to_visit = bvh. nodes_to_visit[Threads. threadid ()]
224
- @inbounds for i in eachindex (nodes_to_visit)
225
- nodes_to_visit[i] = Int32 (0 )
226
- end
227
- primitive:: P = first (bvh. primitives)
228
- primitives = bvh. primitives:: Vector{P}
266
+ to_visit_offset, current_node_i = Int32 (1 ), Int32 (1 )
267
+ # Tuple version is 2us slower, which makes the total rendering time go from 5s to 7s -.-s
268
+ # no other way to do this on the GPU though, is there?
269
+ nodes_to_visit = @ntuple 64 Int32 (0 )
270
+ # nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
271
+ primitives = bvh. primitives
272
+ primitive = first (primitives)
273
+ nodes = bvh. nodes
229
274
@inbounds while true
230
- ln = bvh . nodes[current_node_i]
275
+ ln = nodes[current_node_i]
231
276
if intersect_p (ln. bounds, ray, inv_dir, dir_is_neg)
232
- if ln isa LinearBVHLeaf && ln. n_primitives > 0
277
+ if ! (ln . is_interior) && ln. n_primitives > 0
233
278
# Intersect ray with primitives in node.
234
279
for i in 0 : ln. n_primitives- 1
235
- tmp_primitive:: P = primitives[ln. primitives_offset + i]
280
+ tmp_primitive = primitives[ln. offset + i]
236
281
tmp_hit, ray, tmp_interaction = intersect_p! (
237
282
tmp_primitive, ray,
238
283
)
@@ -247,10 +292,12 @@ function intersect!(bvh::BVHAccel{P}, ray::AbstractRay)::Tuple{Bool,P,SurfaceInt
247
292
current_node_i = nodes_to_visit[to_visit_offset]
248
293
else
249
294
if dir_is_neg[ln. split_axis] == 2
250
- nodes_to_visit[to_visit_offset] = current_node_i + Int32 (1 )
251
- current_node_i = ln. second_child_offset
295
+ @setindex 64 nodes_to_visit[to_visit_offset] = Int32 (current_node_i + 1 )
296
+ # nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
297
+ current_node_i = Int32 (ln. offset)
252
298
else
253
- nodes_to_visit[to_visit_offset] = ln. second_child_offset
299
+ @setindex 64 nodes_to_visit[to_visit_offset] = Int32 (ln. offset)
300
+ # nodes_to_visit[to_visit_offset] = Int32(ln.offset)
254
301
current_node_i += Int32 (1 )
255
302
end
256
303
to_visit_offset += Int32 (1 )
@@ -264,43 +311,45 @@ function intersect!(bvh::BVHAccel{P}, ray::AbstractRay)::Tuple{Bool,P,SurfaceInt
264
311
return hit, primitive, interaction
265
312
end
266
313
267
- function intersect_p (bvh:: BVHAccel , ray:: AbstractRay )
314
+ @inline function intersect_p (bvh:: BVHAccel , ray:: AbstractRay )
315
+
268
316
length (bvh. nodes) == 0 && return false
269
317
270
318
ray = check_direction (ray)
271
319
inv_dir = 1f0 ./ ray. d
272
320
dir_is_neg = is_dir_negative (ray. d)
273
321
274
- to_visit_offset, current_node_i = 1 , 1
275
- @inbounds nodes_to_visit = bvh. nodes_to_visit[Threads. threadid ()]
276
- nodes_to_visit .= 0
277
-
322
+ to_visit_offset, current_node_i = Int32 (1 ), Int32 (1 )
323
+ nodes_to_visit = @ntuple 64 Int32 (0 )
324
+ # nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
278
325
while true
279
326
ln = bvh. nodes[current_node_i]
280
327
if intersect_p (ln. bounds, ray, inv_dir, dir_is_neg)
281
- if ln isa LinearBVHLeaf && ln. n_primitives > 0
328
+ if ! ln . is_interior && ln. n_primitives > 0
282
329
for i in 0 : ln. n_primitives- 1
283
330
intersect_p (
284
- bvh. primitives[ln. primitives_offset + i], ray,
331
+ bvh. primitives[ln. offset + i], ray,
285
332
) && return true
286
333
end
287
334
to_visit_offset == 1 && break
288
- to_visit_offset -= 1
335
+ to_visit_offset -= Int32 ( 1 )
289
336
current_node_i = nodes_to_visit[to_visit_offset]
290
337
else
291
338
if dir_is_neg[ln. split_axis] == 2
292
- nodes_to_visit[to_visit_offset] = current_node_i + 1
293
- current_node_i = ln. second_child_offset
339
+ @setindex 64 nodes_to_visit[to_visit_offset] = Int32 (current_node_i + 1 )
340
+ # nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
341
+ current_node_i = Int32 (ln. offset)
294
342
else
295
- nodes_to_visit[to_visit_offset] = ln. second_child_offset
296
- current_node_i += 1
343
+ @setindex 64 nodes_to_visit[to_visit_offset] = Int32 (ln. offset)
344
+ # nodes_to_visit[to_visit_offset] = Int32(ln.offset)
345
+ current_node_i += Int32 (1 )
297
346
end
298
- to_visit_offset += 1
347
+ to_visit_offset += Int32 ( 1 )
299
348
end
300
349
else
301
350
to_visit_offset == 1 && break
302
- to_visit_offset -= 1
303
- current_node_i = nodes_to_visit[to_visit_offset]
351
+ to_visit_offset -= Int32 ( 1 )
352
+ current_node_i = Int32 ( nodes_to_visit[to_visit_offset])
304
353
end
305
354
end
306
355
false
0 commit comments