Skip to content

Commit d760f40

Browse files
committed
make intersect!(bvh, ray) work on the GPU
1 parent a175b15 commit d760f40

File tree

14 files changed

+568
-338
lines changed

14 files changed

+568
-338
lines changed

docs/code/basic-scene.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,13 @@ using Trace
33
using FileIO
44
using ImageCore
55
using BenchmarkTools
6-
using Makie, FileIO, ImageShow
7-
8-
9-
catmesh = load(Makie.assetpath("cat.obj"))
10-
img = load(Makie.assetpath("diffusemap.png"))
11-
m = normal_mesh(Tesselation(Sphere(Point3f(0), 1), 32))
6+
using FileIO, ImageShow
127

138
function tmesh(prim, material)
149
prim = prim isa Sphere ? Tesselation(prim, 64) : prim
1510
mesh = normal_mesh(prim)
16-
triangles = Trace.create_triangle_mesh(mesh, Trace.ShapeCore())
17-
return [Trace.GeometricPrimitive(t, material) for t in triangles]
11+
m = Trace.create_triangle_mesh(mesh, Trace.ShapeCore())
12+
return Trace.GeometricPrimitive(m, material)
1813
end
1914

2015
material_red = Trace.MatteMaterial(
@@ -52,15 +47,15 @@ begin
5247
l = tmesh(Rect3f(Vec3f(-2, -5, 0), Vec3f(0.01, 10, 10)), material_red)
5348
r = tmesh(Rect3f(Vec3f(2, -5, 0), Vec3f(0.01, 10, 10)), material_blue)
5449

55-
bvh = Trace.BVHAccel([s1..., s2..., s3..., s4..., ground..., back..., l..., r...], 1);
50+
bvh = Trace.BVHAccel([s1, s2, s3, s4, ground, back, l, r], 1);
5651

5752
lights = [
5853
# Trace.PointLight(Vec3f(0, -1, 2), Trace.RGBSpectrum(22.0f0)),
5954
Trace.PointLight(Vec3f(0, 0, 2), Trace.RGBSpectrum(10.0f0)),
6055
Trace.PointLight(Vec3f(0, 3, 3), Trace.RGBSpectrum(15.0f0)),
6156
]
6257
scene = Trace.Scene(lights, bvh);
63-
resolution = Point2f(10)
58+
resolution = Point2f(1024)
6459
f = Trace.LanczosSincFilter(Point2f(1.0f0), 3.0f0)
6560
film = Trace.Film(resolution,
6661
Trace.Bounds2(Point2f(0.0f0), Point2f(1.0f0)),
@@ -72,15 +67,38 @@ begin
7267
Trace.look_at(Point3f(0, 4, 2), Point3f(0, -4, -1), Vec3f(0, 0, 1)),
7368
screen_window, 0.0f0, 1.0f0, 0.0f0, 1.0f6, 45.0f0, film,
7469
)
75-
7670
end
71+
7772
begin
78-
integrator = Trace.WhittedIntegrator(cam, Trace.UniformSampler(8), 1)
79-
@time integrator(scene)
73+
integrator = Trace.WhittedIntegrator(cam, Trace.UniformSampler(8), 5)
74+
@time integrator(scene, film)
8075
img = reverse(film.framebuffer, dims=1)
8176
end
77+
# 6.7s
78+
79+
80+
camera_sample = Trace.get_camera_sample(integrator.sampler, Point2f(512))
81+
ray, ω = Trace.generate_ray_differential(integrator.camera, camera_sample)
82+
83+
@btime Trace.intersect_p(bvh, ray)
84+
@btime Trace.intersect!(bvh, ray)
85+
86+
###
87+
# Int32 always
88+
# 42.000 μs (1 allocation: 624 bytes)
89+
# Tuple instead of vector for nodes_to_visit
90+
# 43.400 μs (1 allocation: 624 bytes)
91+
# AFTER GPU rework
92+
# intersect!
93+
# 40.500 μs (1 allocation: 368 bytes)
94+
# intersect_p
95+
# 11.500 μs (0 allocations: 0 bytes)
96+
97+
### LinearBVHLeaf as one type
98+
# 5.247460 seconds (17.55 k allocations: 19.783 MiB, 46 lock conflicts)
99+
82100
# begin
83-
# integrator = Trace.SPPMIntegrator(cam, 0.075f0, 5, 100)
101+
# integrator = Trace.SPPMIntegrator(cam, 0.075f0, 5, 1)
84102
# integrator(scene)
85103
# img = reverse(film.framebuffer, dims=1)
86104
# end

src/Trace.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ end
192192
@inline function intersect!(scene::Scene, ray::AbstractRay)
193193
intersect!(scene.aggregate, ray)
194194
end
195+
195196
@inline function intersect_p(scene::Scene, ray::AbstractRay)
196197
intersect_p(scene.aggregate, ray)
197198
end

src/accel/bvh.jl

Lines changed: 118 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -35,61 +35,83 @@ struct BVHNode
3535
end
3636

3737
abstract type LinearNode end
38-
struct LinearBVHLeaf <: LinearNode
38+
39+
struct LinearBVH <: LinearNode
3940
bounds::Bounds3
40-
primitives_offset::UInt32
41+
offset::UInt32
4142
n_primitives::UInt32
42-
end
43-
struct LinearBVHInterior <: LinearNode
44-
bounds::Bounds3
45-
second_child_offset::UInt32
4643
split_axis::UInt8
44+
is_interior::Bool
4745
end
48-
const LinearBVH = Union{LinearBVHLeaf,LinearBVHInterior}
4946

50-
struct BVHAccel{P <: Primitive} <: AccelPrimitive
51-
primitives::Vector{P}
52-
max_node_primitives::UInt8
53-
nodes::Vector{LinearBVH}
54-
nodes_to_visit::Vector{Vector{Int32}}
47+
function LinearBVHLeaf(bounds::Bounds3, primitives_offset::Integer, n_primitives::Integer)
48+
LinearBVH(bounds, primitives_offset, n_primitives, 0, false)
49+
end
50+
function LinearBVHInterior(bounds::Bounds3, second_child_offset::Integer, split_axis::Integer)
51+
LinearBVH(bounds, second_child_offset, 0, split_axis, true)
52+
end
5553

56-
function BVHAccel(
57-
primitives::Vector{P}, max_node_primitives::Integer = 1,
58-
) where P<:Primitive
59-
max_node_primitives = min(255, max_node_primitives)
60-
isempty(primitives) && return new{P}(primitives, max_node_primitives)
61-
nodes_to_visit = [zeros(Int32, 64) for _ in 1:Threads.maxthreadid()]
62-
primitives_info = [
63-
BVHPrimitiveInfo(i, world_bound(p))
64-
for (i, p) in enumerate(primitives)
65-
]
54+
function primitives_to_bvh(primitives, max_node_primitives=1)
55+
max_node_primitives = min(255, max_node_primitives)
56+
isempty(primitives) && return (primitives, max_node_primitives, LinearBVH[])
57+
primitives_info = [
58+
BVHPrimitiveInfo(i, world_bound(p))
59+
for (i, p) in enumerate(primitives)
60+
]
61+
total_nodes = Ref(0)
62+
ordered_primitives = similar(primitives, 0)
63+
root = _init(
64+
primitives, primitives_info, 1, length(primitives),
65+
total_nodes, ordered_primitives, max_node_primitives,
66+
)
6667

67-
total_nodes = Ref(0)
68-
ordered_primitives = P[]
69-
root = _init(
70-
primitives, primitives_info, 1, length(primitives),
71-
total_nodes, ordered_primitives, max_node_primitives,
72-
)
68+
offset = Ref{UInt32}(1)
69+
flattened = Vector{LinearBVH}(undef, total_nodes[])
70+
_unroll(flattened, root, offset)
71+
@real_assert total_nodes[] + 1 == offset[]
72+
return (ordered_primitives, max_node_primitives, flattened)
73+
end
7374

74-
offset = Ref{UInt32}(1)
75-
flattened = Vector{LinearBVH}(undef, total_nodes[])
76-
_unroll(flattened, root, offset)
77-
@real_assert total_nodes[] + 1 == offset[]
75+
struct BVHAccel{
76+
PVec <:AbstractVector,
77+
MatVec <: AbstractVector{<:Material},
78+
NodeVec <: AbstractVector{LinearBVH}
79+
} <: AccelPrimitive
80+
primitives::PVec
81+
materials::MatVec
82+
max_node_primitives::UInt8
83+
nodes::NodeVec
84+
end
7885

79-
new{P}(ordered_primitives, max_node_primitives, flattened, nodes_to_visit)
86+
function BVHAccel(
87+
primitives::AbstractVector{P}, max_node_primitives::Integer=1,
88+
) where {P}
89+
materials = map(x-> x.material, primitives)
90+
meshes = map(x-> x.shape, primitives)
91+
triangles = Triangle[]
92+
for (mi, m) in enumerate(meshes)
93+
vertices = m.vertices
94+
for i in 1:div(length(m.indices), 3)
95+
push!(triangles, Triangle(m, i, mi))
96+
end
8097
end
98+
ordered_primitives, max_prim, nodes = primitives_to_bvh(triangles, max_node_primitives)
99+
return BVHAccel(ordered_primitives, materials, UInt8(max_prim), nodes)
81100
end
82101

102+
103+
83104
mutable struct BucketInfo
84105
count::UInt32
85106
bounds::Bounds3
86107
end
87108

88109
function _init(
89-
primitives::Vector{P}, primitives_info::Vector{BVHPrimitiveInfo},
90-
from::Integer, to::Integer, total_nodes::Ref{Int64},
91-
ordered_primitives::Vector{P}, max_node_primitives::Integer,
92-
) where P<:Primitive
110+
primitives::AbstractVector, primitives_info::Vector{BVHPrimitiveInfo},
111+
from::Integer, to::Integer, total_nodes::Ref{Int64},
112+
ordered_primitives::AbstractVector, max_node_primitives::Integer,
113+
)
114+
93115
total_nodes[] += 1
94116
n_primitives = to - from + 1
95117
# Compute bounds for all primitives in BVH node.
@@ -186,8 +208,9 @@ function _init(
186208
end
187209

188210
function _unroll(
189-
linear_nodes::Vector{LinearBVH}, node::BVHNode, offset::Ref{UInt32},
190-
)
211+
linear_nodes::Vector{LinearBVH}, node::BVHNode, offset::Ref{UInt32},
212+
)
213+
191214
l_offset = offset[]
192215
offset[] += 1
193216

@@ -210,29 +233,51 @@ end
210233
length(bvh.nodes) > 0 ? bvh.nodes[1].bounds : Bounds3()
211234
end
212235

213-
function intersect!(bvh::BVHAccel{P}, ray::AbstractRay)::Tuple{Bool,P,SurfaceInteraction} where {P}
236+
macro ntuple(N, value)
237+
expr = :(())
238+
for i in 1:N
239+
push!(expr.args, :($(esc(value))))
240+
end
241+
return expr
242+
end
243+
244+
macro setindex(N, setindex_expr)
245+
@assert Meta.isexpr(setindex_expr, :(=))
246+
index_expr = setindex_expr.args[1]
247+
@assert Meta.isexpr(index_expr, :ref)
248+
tuple = index_expr.args[1]
249+
idx = index_expr.args[2]
250+
value = setindex_expr.args[2]
251+
expr = :(())
252+
for i in 1:N
253+
push!(expr.args, :(ifelse($i != $(esc(idx)), $(esc(tuple))[$i], $(esc(value)))))
254+
end
255+
return :($(esc(tuple)) = $expr)
256+
end
257+
258+
@inline function intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
214259
hit = false
215260
interaction = SurfaceInteraction()
216-
isempty(bvh.nodes) && return hit, nothing, interaction
217261

218262
ray = check_direction(ray)
219263
inv_dir = 1f0 ./ ray.d
220264
dir_is_neg = is_dir_negative(ray.d)
221265

222-
to_visit_offset::Int32, current_node_i::Int32 = 1, 1
223-
@inbounds nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
224-
@inbounds for i in eachindex(nodes_to_visit)
225-
nodes_to_visit[i] = Int32(0)
226-
end
227-
primitive::P = first(bvh.primitives)
228-
primitives = bvh.primitives::Vector{P}
266+
to_visit_offset, current_node_i = Int32(1), Int32(1)
267+
# Tuple version is 2us slower, which makes the total rendering time go from 5s to 7s -.-
268+
# no other way to do this on the GPU though, is there?
269+
nodes_to_visit = @ntuple 64 Int32(0)
270+
# nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
271+
primitives = bvh.primitives
272+
primitive = first(primitives)
273+
nodes = bvh.nodes
229274
@inbounds while true
230-
ln = bvh.nodes[current_node_i]
275+
ln = nodes[current_node_i]
231276
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
232-
if ln isa LinearBVHLeaf && ln.n_primitives > 0
277+
if !(ln.is_interior) && ln.n_primitives > 0
233278
# Intersect ray with primitives in node.
234279
for i in 0:ln.n_primitives-1
235-
tmp_primitive::P = primitives[ln.primitives_offset+i]
280+
tmp_primitive = primitives[ln.offset+i]
236281
tmp_hit, ray, tmp_interaction = intersect_p!(
237282
tmp_primitive, ray,
238283
)
@@ -247,10 +292,12 @@ function intersect!(bvh::BVHAccel{P}, ray::AbstractRay)::Tuple{Bool,P,SurfaceInt
247292
current_node_i = nodes_to_visit[to_visit_offset]
248293
else
249294
if dir_is_neg[ln.split_axis] == 2
250-
nodes_to_visit[to_visit_offset] = current_node_i + Int32(1)
251-
current_node_i = ln.second_child_offset
295+
@setindex 64 nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
296+
# nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
297+
current_node_i = Int32(ln.offset)
252298
else
253-
nodes_to_visit[to_visit_offset] = ln.second_child_offset
299+
@setindex 64 nodes_to_visit[to_visit_offset] = Int32(ln.offset)
300+
# nodes_to_visit[to_visit_offset] = Int32(ln.offset)
254301
current_node_i += Int32(1)
255302
end
256303
to_visit_offset += Int32(1)
@@ -264,43 +311,45 @@ function intersect!(bvh::BVHAccel{P}, ray::AbstractRay)::Tuple{Bool,P,SurfaceInt
264311
return hit, primitive, interaction
265312
end
266313

267-
function intersect_p(bvh::BVHAccel, ray::AbstractRay)
314+
@inline function intersect_p(bvh::BVHAccel, ray::AbstractRay)
315+
268316
length(bvh.nodes) == 0 && return false
269317

270318
ray = check_direction(ray)
271319
inv_dir = 1f0 ./ ray.d
272320
dir_is_neg = is_dir_negative(ray.d)
273321

274-
to_visit_offset, current_node_i = 1, 1
275-
@inbounds nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
276-
nodes_to_visit .= 0
277-
322+
to_visit_offset, current_node_i = Int32(1), Int32(1)
323+
nodes_to_visit = @ntuple 64 Int32(0)
324+
# nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
278325
while true
279326
ln = bvh.nodes[current_node_i]
280327
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
281-
if ln isa LinearBVHLeaf && ln.n_primitives > 0
328+
if !ln.is_interior && ln.n_primitives > 0
282329
for i in 0:ln.n_primitives-1
283330
intersect_p(
284-
bvh.primitives[ln.primitives_offset+i], ray,
331+
bvh.primitives[ln.offset+i], ray,
285332
) && return true
286333
end
287334
to_visit_offset == 1 && break
288-
to_visit_offset -= 1
335+
to_visit_offset -= Int32(1)
289336
current_node_i = nodes_to_visit[to_visit_offset]
290337
else
291338
if dir_is_neg[ln.split_axis] == 2
292-
nodes_to_visit[to_visit_offset] = current_node_i + 1
293-
current_node_i = ln.second_child_offset
339+
@setindex 64 nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
340+
# nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
341+
current_node_i = Int32(ln.offset)
294342
else
295-
nodes_to_visit[to_visit_offset] = ln.second_child_offset
296-
current_node_i += 1
343+
@setindex 64 nodes_to_visit[to_visit_offset] = Int32(ln.offset)
344+
# nodes_to_visit[to_visit_offset] = Int32(ln.offset)
345+
current_node_i += Int32(1)
297346
end
298-
to_visit_offset += 1
347+
to_visit_offset += Int32(1)
299348
end
300349
else
301350
to_visit_offset == 1 && break
302-
to_visit_offset -= 1
303-
current_node_i = nodes_to_visit[to_visit_offset]
351+
to_visit_offset -= Int32(1)
352+
current_node_i = Int32(nodes_to_visit[to_visit_offset])
304353
end
305354
end
306355
false

src/bounds.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
function Base.getindex(b::Union{Bounds2,Bounds3}, i::Integer)
2626
i == 1 && return b.p_min
2727
i == 2 && return b.p_max
28-
error("Invalid index `$i`. Only `1` & `2` are valid.")
28+
# error("Invalid index `$i`. Only `1` & `2` are valid.")
2929
end
3030
function is_valid(b::Bounds3)::Bool
3131
all(b.p_min .!= Inf32) && all(b.p_max .!= -Inf32)

src/camera/camera.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ struct CameraCore
44
camera_to_world::Transformation
55
shutter_open::Float32
66
shutter_close::Float32
7-
film::Film
8-
# medium::Medium
97
end
108

119
struct CameraSample

src/camera/perspective.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct ProjectiveCamera <: Camera
1515
lens_radius::Float32, focal_distance::Float32,
1616
film::Film,
1717
)
18-
core = CameraCore(camera_to_world, shutter_open, shutter_close, film)
18+
core = CameraCore(camera_to_world, shutter_open, shutter_close)
1919
# Compute projective camera transformations.
2020
resolution = scale(film.resolution..., 1)
2121
window_width = screen_window.p_max .- screen_window.p_min

0 commit comments

Comments
 (0)