Skip to content

Commit 018ffe8

Browse files
committed
fix bugs + perf improvements
1 parent 107d30c commit 018ffe8

File tree

13 files changed

+220
-170
lines changed

13 files changed

+220
-170
lines changed

src/Trace.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@ abstract type Material end
1919
abstract type BxDF end
2020
abstract type Integrator end
2121

22-
const Radiance = Val{:Radiance}
23-
const Importance = Val{:Importance}
24-
const TransportMode = Union{Radiance,Importance}
22+
const Radiance = UInt8(1)
23+
const Importance = UInt8(2)
2524

2625
const DO_ASSERTS = false
2726
macro real_assert(expr, msg="")

src/accel/bvh.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ struct BVHAccel <: AccelPrimitive
5151
primitives::Vector{P} where P<:Primitive
5252
max_node_primitives::UInt8
5353
nodes::Vector{LinearBVH}
54+
nodes_to_visit::Vector{Vector{Int32}}
5455

5556
function BVHAccel(
5657
primitives::Vector{P}, max_node_primitives::Integer = 1,
5758
) where P<:Primitive
5859
max_node_primitives = min(255, max_node_primitives)
5960
length(primitives) == 0 && return new(primitives, max_node_primitives)
60-
61+
nodes_to_visit = [zeros(Int32, 64) for _ in 1:Threads.maxthreadid()]
6162
primitives_info = [
6263
BVHPrimitiveInfo(i, world_bound(p))
6364
for (i, p) in enumerate(primitives)
@@ -75,7 +76,7 @@ struct BVHAccel <: AccelPrimitive
7576
_unroll(flattened, root, offset)
7677
@real_assert total_nodes[] + 1 == offset[]
7778

78-
new(ordered_primitives, max_node_primitives, flattened)
79+
new(ordered_primitives, max_node_primitives, flattened, nodes_to_visit)
7980
end
8081
end
8182

@@ -220,9 +221,9 @@ function intersect!(pool, bvh::BVHAccel, ray::MutableRef{<:AbstractRay})
220221
dir_is_neg = is_dir_negative(ray.d)
221222

222223
to_visit_offset, current_node_i = 1, 1
223-
nodes_to_visit = zeros(Int32, 64)
224-
225-
while true
224+
@inbounds nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
225+
nodes_to_visit .= 0
226+
@inbounds while true
226227
ln = bvh.nodes[current_node_i]
227228
if intersect_p(pool, ln.bounds, ray, inv_dir, dir_is_neg)
228229
if ln isa LinearBVHLeaf && ln.n_primitives > 0
@@ -268,7 +269,8 @@ function intersect_p(pool, bvh::BVHAccel, ray::MutableRef{<:AbstractRay})
268269
dir_is_neg = is_dir_negative(ray.d)
269270

270271
to_visit_offset, current_node_i = 1, 1
271-
nodes_to_visit = zeros(Int32, 64)
272+
@inbounds nodes_to_visit = bvh.nodes_to_visit[Threads.threadid()]
273+
nodes_to_visit .= 0
272274

273275
while true
274276
ln = bvh.nodes[current_node_i]

src/camera/camera.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ function generate_ray_differential(
5858
)
5959
ray_x, wt_x = generate_ray(pool, camera, shifted_x)
6060
ray_y, wt_y = generate_ray(pool, camera, shifted_y)
61-
ray = allocate(pool, RayDifferentials,
61+
rayd = allocate(pool, RayDifferentials,
6262
(ray.o, ray.d, ray.t_max, ray.time,
6363
true, ray_x.o, ray_y.o, ray_x.d, ray_y.d)
6464
)
65-
ray, wt
65+
rayd, wt
6666
end
6767

6868
include("perspective.jl")

src/integrators/sampler.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function (i::SamplerIntegrator)(scene::Scene)
2222
bar = Progress(total_tiles, 1)
2323

2424
@info "Utilizing $(Threads.nthreads()) threads"
25-
mempools = [MemoryPool(round(Int, 16384)) for _ in 1:Threads.maxthreadid()]
25+
mempools = [MemoryPool(round(Int, 2*16384)) for _ in 1:Threads.maxthreadid()]
2626
Threads.@threads for k in 0:total_tiles
2727
x, y = k % width, k ÷ width
2828
tile = Point2f(x, y)
@@ -38,17 +38,16 @@ function (i::SamplerIntegrator)(scene::Scene)
3838
for pixel in tile_bounds
3939
start_pixel!(t_sampler, pixel)
4040
while has_next_sample(t_sampler)
41-
LifeCycle(pool) do pool
42-
camera_sample = get_camera_sample(t_sampler, pixel)
43-
ray, ω = generate_ray_differential(pool, i.camera, camera_sample)
44-
scale_differentials!(ray, spp_sqr)
45-
l = RGBSpectrum(0f0)
46-
ω > 0.0f0 && (l = li(pool, i, ray, scene, 1))
47-
# TODO check l for invalid values
48-
isnan(l) && (l = RGBSpectrum(0f0))
49-
add_sample!(film_tile, camera_sample.film, l, ω)
50-
start_next_sample!(t_sampler)
51-
end
41+
free_all(pool) # clear memory pool
42+
camera_sample = get_camera_sample(t_sampler, pixel)
43+
ray, ω = generate_ray_differential(pool, i.camera, camera_sample)
44+
scale_differentials!(ray, spp_sqr)
45+
l = RGBSpectrum(0f0)
46+
ω > 0.0f0 && (l = li(pool, i, ray, scene, 1))
47+
# TODO check l for invalid values
48+
isnan(l) && (l = RGBSpectrum(0f0))
49+
add_sample!(film_tile, camera_sample.film, l, ω)
50+
start_next_sample!(t_sampler)
5251
end
5352
end
5453
merge_film_tile!(get_film(i.camera), film_tile)

src/integrators/sppm.jl

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ mutable struct SPPMPixelListNode
105105
end
106106
end
107107

108-
struct SPPMIntegrator <: Integrator
109-
camera::C where C<:Camera
108+
struct SPPMIntegrator{C<:Camera} <: Integrator
109+
camera::C
110110
initial_search_radius::Float32
111111
max_depth::Int64
112112
n_iterations::Int64
@@ -122,7 +122,7 @@ struct SPPMIntegrator <: Integrator
122122
photons_per_iteration > 0
123123
? photons_per_iteration : area(get_film(camera).crop_bounds)
124124
)
125-
new(
125+
new{C}(
126126
camera, initial_search_radius, max_depth,
127127
n_iterations, photons_per_iteration, write_frequency,
128128
)
@@ -149,16 +149,17 @@ function (i::SPPMIntegrator)(scene::Scene)
149149
n_tiles::Point2 = Int64.(floor.((pixel_extent .+ tile_size) ./ tile_size))
150150

151151
sampler = UniformSampler(1)
152+
mempools = [MemoryPool(round(Int, 100_000)) for _ in 1:Threads.maxthreadid()]
152153
for iteration in 1:i.n_iterations
153154
_generate_visible_sppm_points!(
154-
i, pixels, scene,
155+
mempools, i, pixels, scene,
155156
n_tiles, tile_size, sampler,
156157
pixel_bounds, inv_sqrt_spp,
157158
)
158159
_clean_grid!(grid)
159160
grid_bounds, grid_resolution = _populate_grid!(grid, pixels)
160161
_trace_photons!(
161-
i, scene, iteration, light_distribution,
162+
mempools, i, scene, iteration, light_distribution,
162163
grid, grid_bounds, grid_resolution,
163164
n_pixels,
164165
)
@@ -173,14 +174,16 @@ function (i::SPPMIntegrator)(scene::Scene)
173174
end
174175

175176
function _generate_visible_sppm_points!(
176-
i::SPPMIntegrator, pixels::Matrix{SPPMPixel}, scene::Scene,
177-
n_tiles::Point2, tile_size::Int64, sampler::S,
178-
pixel_bounds::Bounds2, inv_sqrt_spp::Float32,
179-
) where S<:AbstractSampler
177+
mempools, i::SPPMIntegrator, pixels::Matrix{SPPMPixel}, scene::Scene,
178+
n_tiles::Point2, tile_size::Int64, sampler::S,
179+
pixel_bounds::Bounds2, inv_sqrt_spp::Float32,
180+
) where S<:AbstractSampler
181+
180182
width, height = n_tiles
181183
total_tiles = width * height - 1
182184

183185
bar = get_progress_bar(total_tiles, "Camera pass: ")
186+
184187
Threads.@threads for k in 0:total_tiles
185188
x, y = k % width, k ÷ width
186189
tile = Point2f(x, y)
@@ -189,45 +192,47 @@ function _generate_visible_sppm_points!(
189192
tb_min = pixel_bounds.p_min .+ tile .* tile_size
190193
tb_max = min.(tb_min .+ (tile_size - 1), pixel_bounds.p_max)
191194
tile_bounds = Bounds2(tb_min, tb_max)
195+
pool = mempools[Threads.threadid()]
192196
for pixel_point in tile_bounds
197+
free_all(pool)
193198
start_pixel!(tile_sampler, pixel_point)
194199
# set_sample_number!(tile_sampler, iteration)
195200

196201
camera_sample = get_camera_sample(tile_sampler, pixel_point)
197-
ray, β = generate_ray_differential(i.camera, camera_sample)
202+
rayd::RayDifferentials, β = generate_ray_differential(pool, i.camera, camera_sample)
198203
β 0f0 && continue
199204
β = RGBSpectrum(β)
200205
@real_assert !isnan(β)
201-
scale_differentials!(ray, inv_sqrt_spp)
206+
scale_differentials!(rayd, inv_sqrt_spp)
202207
# Follow camera ray path until a visible point is created.
203208
# Get SPPMPixel for current `pixel`.
204209
pixel_point = Int64.(pixel_point)
205210
pixel = pixels[pixel_point[2], pixel_point[1]]
206211
specular_bounce = false
207212
depth = 1
208213
while depth i.max_depth
209-
hit, surface_interaction = intersect!(scene, ray)
214+
hit, primitive, surface_interaction = intersect!(pool, scene, rayd)
210215
if !hit # Accumulate light contributions to the background.
211216
for light in scene.lights
212-
pixel.Ld += β * le(light, ray)
217+
pixel.Ld += β * le(light, rayd)
213218
end
214219
break
215220
end
216221
# Process SPPM camera ray intersection.
217222
# Compute BSDF at SPPM camera ray intersection.
218-
bsdf = compute_scattering!(pool, surface_interaction, ray, true)
223+
bsdf = compute_scattering!(pool, primitive, surface_interaction, rayd, true)
219224
if bsdf nothing
220-
ray = spawn_ray(surface_interaction, ray.d)
225+
rayd = allocate(pool, RayDifferentials, spawn_ray(pool, surface_interaction, rayd.d))
221226
continue
222227
end
223228
# Accumulate direct illumination at
224229
# SPPM camera-ray intersection.
225-
wo = -ray.d
230+
wo = -rayd.d
226231
if depth == 1 || specular_bounce
227232
pixel.Ld += β * le(surface_interaction, wo)
228233
end
229234
pixel.Ld += uniform_sample_one_light(
230-
surface_interaction, scene, tile_sampler,
235+
pool, bsdf, surface_interaction, scene, tile_sampler,
231236
)
232237
# Possibly create visible point and end camera path.
233238
is_diffuse = num_components(bsdf,
@@ -261,7 +266,7 @@ function _generate_visible_sppm_points!(
261266
β /= continue_probability
262267
@real_assert !isnan(β) && !isinf(β)
263268
end
264-
ray = RayDifferentials(spawn_ray(surface_interaction, wi))
269+
rayd = allocate(pool, RayDifferentials, spawn_ray(pool, surface_interaction, wi))
265270
depth += 1
266271
end
267272
end
@@ -318,7 +323,7 @@ function _populate_grid!(
318323
end
319324

320325
function _trace_photons!(
321-
i::SPPMIntegrator, scene::Scene, iteration::Int64,
326+
mempools, i::SPPMIntegrator, scene::Scene, iteration::Int64,
322327
light_distribution::Distribution1D,
323328
grid::Vector{Maybe{SPPMPixelListNode}},
324329
grid_bounds::Bounds3, grid_resolution::Point3,
@@ -332,6 +337,7 @@ function _trace_photons!(
332337
shutter_open = i.camera.core.core.shutter_open
333338
shutter_close = i.camera.core.core.shutter_close
334339
Threads.@threads for photon_index in 0:i.photons_per_iteration-1
340+
pool = mempools[Threads.threadid()]
335341
# Follow photon path for `photon_index`.
336342
halton_index = halton_base + photon_index
337343
halton_dim = 0
@@ -359,10 +365,10 @@ function _trace_photons!(
359365
halton_dim += 5
360366
# Generate `photon_ray` from light source and initialize β.
361367
le, ray, light_normal, pdf_pos, pdf_dir = sample_le(
362-
light, u_light_0, u_light_1, u_light_time,
368+
pool, light, u_light_0, u_light_1, u_light_time,
363369
)
364370
(pdf_pos 0f0 || pdf_dir 0f0 || is_black(le)) && continue
365-
photon_ray = RayDifferentials(ray)
371+
photon_ray = allocate(pool, RayDifferentials, ray)
366372
β = abs(light_normal photon_ray.d) * le / (
367373
light_pdf * pdf_pos * pdf_dir
368374
)
@@ -372,7 +378,7 @@ function _trace_photons!(
372378
# Follow photon path through scene and record intersections.
373379
depth = 1
374380
while depth i.max_depth
375-
hit, interaction = intersect!(scene, photon_ray)
381+
hit, primitive, interaction = intersect!(pool, scene, photon_ray)
376382
!hit && break
377383
if depth > 1
378384
# Add photon contribution to nearby visible points.
@@ -385,8 +391,8 @@ function _trace_photons!(
385391
node::Maybe{SPPMPixelListNode} = grid[h]
386392
while node nothing
387393
if distance_squared(
388-
node.pixel.vp.p, interaction.core.p,
389-
) > (node.pixel.radius^2)
394+
node.pixel.vp.p, interaction.core.p,
395+
) > (node.pixel.radius^2)
390396
node = node.next
391397
continue
392398
end
@@ -403,9 +409,10 @@ function _trace_photons!(
403409
end
404410
# Sample new photon direction.
405411
# Compute BSDF at photon intersection point.
406-
compute_scattering!(pool, interaction, photon_ray, true, Importance)
407-
if interaction.bsdf nothing
408-
photon_ray = spawn_ray(interaction, photon_ray.d)
412+
bsdf = compute_scattering!(pool, primitive, interaction, photon_ray, true, Importance)
413+
414+
if bsdf nothing
415+
photon_ray = spawn_ray(pool, interaction, photon_ray.d)
409416
continue
410417
end
411418
# Sample BSDF spectrum and direction `wi` for reflected photon.
@@ -415,7 +422,7 @@ function _trace_photons!(
415422
)
416423
halton_dim += 2
417424
wi, fr, pdf, sampled_type = sample_f(
418-
interaction.bsdf, -photon_ray.d, bsdf_sample, BSDF_ALL,
425+
bsdf, -photon_ray.d, bsdf_sample, BSDF_ALL,
419426
)
420427
(is_black(fr) || pdf 0f0) && break
421428

@@ -428,7 +435,7 @@ function _trace_photons!(
428435
end
429436
halton_dim += 1
430437
# β = β_new / (1f0 - q)
431-
photon_ray = RayDifferentials(spawn_ray(interaction, wi))
438+
photon_ray = allocate(pool, RayDifferentials, spawn_ray(pool, interaction, wi))
432439
depth += 1
433440
end
434441
next!(bar)
@@ -477,8 +484,9 @@ Calculate indices of a point `p` in grid constrained by `bounds`.
477484
Computed indices are in [0, resolution), which is the correct input for `hash`.
478485
"""
479486
@inline function to_grid(
480-
p::Point3f, bounds::Bounds3, grid_resolution::Point3,
481-
)::Tuple{Bool,Point3{UInt64}}
487+
p::Point3f, bounds::Bounds3, grid_resolution::Point3,
488+
)::Tuple{Bool,Point3{UInt64}}
489+
482490
p_offset = offset(bounds, p)
483491
grid_point = Point3{Int64}(
484492
floor(grid_resolution[1] * p_offset[1]),
@@ -501,8 +509,9 @@ end
501509
end
502510

503511
function uniform_sample_one_light(
504-
i::SurfaceInteraction, scene::Scene, sampler::S,
505-
)::RGBSpectrum where S<:AbstractSampler
512+
pool, bsdf, i::SurfaceInteraction, scene::Scene, sampler::S,
513+
)::RGBSpectrum where S<:AbstractSampler
514+
506515
n_lights = length(scene.lights)
507516
n_lights == 0 && return RGBSpectrum(0f0)
508517

@@ -513,21 +522,22 @@ function uniform_sample_one_light(
513522
u_light = get_2d(sampler)
514523
u_scatter = get_2d(sampler)
515524

516-
estimate_direct(i, u_scatter, light, u_light, scene, sampler) / light_pdf
525+
estimate_direct(pool, bsdf, i, u_scatter, light, u_light, scene, sampler) / light_pdf
517526
end
518527

519528
function estimate_direct(
520-
interaction::SurfaceInteraction, u_scatter::Point2f, light::L,
521-
u_light::Point2f, scene::Scene, sampler::S, specular::Bool = false,
522-
)::RGBSpectrum where {L<:Light,S<:AbstractSampler}
529+
pool, bsdf, interaction::SurfaceInteraction, u_scatter::Point2f, light::L,
530+
u_light::Point2f, scene::Scene, sampler::S, specular::Bool = false,
531+
)::RGBSpectrum where {L<:Light,S<:AbstractSampler}
532+
523533
bsdf_flags = specular ? BSDF_ALL : (BSDF_ALL & ~BSDF_SPECULAR)
524534
Ld = RGBSpectrum(0f0)
525535
# Sample light source with multiple importance sampling.
526-
Li, wi, light_pdf, visibility = sample_li(light, interaction.core, u_light)
536+
Li, wi, light_pdf, visibility = sample_li(pool, light, interaction.core, u_light)
527537
if light_pdf > 0 && !is_black(Li)
528538
# Evaluate BSDF for light sampling strategy.
529539
f = (
530-
interaction.bsdf(interaction.core.wo, wi, bsdf_flags)
540+
bsdf(interaction.core.wo, wi, bsdf_flags)
531541
*
532542
abs(wi interaction.shading.n)
533543
)
@@ -540,7 +550,7 @@ function estimate_direct(
540550
else
541551
@real_assert false # TODO no non delta lights right now
542552
scattering_pdf = compute_pdf(
543-
interaction.bsdf, interaction.core.wo, wi, bsdf_flags,
553+
bsdf, interaction.core.wo, wi, bsdf_flags,
544554
)
545555
weight = power_heuristic(1, light_pdf, 1, scattering_pdf)
546556
Ld += f * Li * weight / light_pdf

src/materials/bsdf.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ function (b::BSDF)(
9191

9292
output = RGBSpectrum(0f0)
9393
bxdfs = b.bxdfs
94+
bxdfs.last == 0 && return output
9495
Base.Cartesian.@nexprs 8 i -> begin
9596
@assert i <= bxdfs.last
9697
bxdf = bxdfs[i]

0 commit comments

Comments
 (0)