diff --git a/LocalPreferences.toml b/LocalPreferences.toml deleted file mode 100644 index f6e7c7e..0000000 --- a/LocalPreferences.toml +++ /dev/null @@ -1,9 +0,0 @@ -[AMDGPU] -soft_memory_limit = "80 %" -hard_memory_limit = "80 %" - -[Flux] -gpu_backend = "AMDGPU" - -[GaussianSplatting] -gpu_backend = "AMDGPU" diff --git a/Project.toml b/Project.toml index 2c955a8..5c60c8e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,13 +6,11 @@ authors = ["Anton Smirnov "] [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CImGui = "5d785b6c-b76f-510e-a07c-3070796c7e87" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GLFW = "f7f18e0c-5ee9-5ccd-a5bf-e8befd85ed98" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -29,7 +27,6 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" NerfUtils = "99c1d5ce-7c61-4a25-a107-a5ade2e2a8e4" NeuralGraphicsGL = "263f7e6d-e369-49af-a86e-c85638573b76" PlyIO = "42171d58-473b-503a-8d5f-782019eb09ec" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" @@ -42,23 +39,20 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] GaussianSplattingAMDGPUExt = "AMDGPU" -GaussianSplattingCUDAExt = ["CUDA", "cuDNN"] +GaussianSplattingCUDAExt = "CUDA" [compat] AMDGPU = "2" Adapt = "4" BSON = "0.3" -BenchmarkTools = "1" CImGui = "6" CUDA = "5" ChainRulesCore = "1" Distributions = 
"0.25" FileIO = "1.16" -Flux = "0.16.2" GLFW = "3.4" GPUArrays = "11.2.1" GPUArraysCore = "0.2" @@ -74,12 +68,10 @@ NearestNeighbors = "0.4" NerfUtils = "0.2" NeuralGraphicsGL = "0.5" PlyIO = "1.2" -Preferences = "1.4" Quaternions = "0.7" Rotations = "1.7" SIMD = "3.6" StaticArrays = "1.9" VideoIO = "1.1" Zygote = "=0.7.3" -cuDNN = "1.3" julia = "1.10" diff --git a/docs/src/index.md b/docs/src/index.md index 03769f4..ea9269f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -25,17 +25,14 @@ GaussianSplatting.jl comes with a GUI application to train & view the gaussians. 1. Add necessary packages: ```julia ] add AMDGPU # for AMD GPU - ] add CUDA, cuDNN # for Nvidia GPU - ] add Flux + ] add CUDA # for Nvidia GPU ``` 2. Run: ```julia - julia> using AMDGPU # for AMD GPU - julia> using CUDA, cuDNN # for Nvidia GPU - julia> using Flux, GaussianSplatting - - julia> GaussianSplatting.gui("path-to-colmap-dataset-directory"; scale=1) + julia> using AMDGPU, GaussianSplatting; kab = ROCBackend() # for AMD GPU + julia> using CUDA, GaussianSplatting; kab = CUDABackend() # for Nvidia GPU + julia> GaussianSplatting.gui(kab, "path-to-colmap-dataset-directory"; scale=1) ``` ## Viewer mode @@ -44,27 +41,12 @@ Once you've trained a model and saved it to `.bson` file you can open it in a viewer-only mode by providing its path. ```julia -julia> GaussianSplatting.gui("path-to-checkpoint.bson") +julia> GaussianSplatting.gui(kab, "path-to-checkpoint.bson") ``` Alternative, you can load a checkpoint in a training mode (see **Usage** section) using "Save/Load" tab. -## GPU selection - -This is required only the first time per the environment. -After selecting GPU backend, restart Julia REPL. 
- -- AMD GPU: - ```julia - julia> Flux.gpu_backend!("AMDGPU") - ``` - -- Nvidia GPU: - ```julia - julia> Flux.gpu_backend!("CUDA") - ``` - ## References - 3D Gaussian Splatting for Real-Time Radiance Field Rendering: diff --git a/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl b/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl index ee0761e..035bebf 100644 --- a/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl +++ b/ext/GaussianSplattingCUDAExt/GaussianSplattingCUDAExt.jl @@ -2,7 +2,6 @@ module GaussianSplattingCUDAExt # using Adapt using CUDA -# using cuDNN # using KernelAbstractions using GaussianSplatting # using PrecompileTools diff --git a/src/GaussianSplatting.jl b/src/GaussianSplatting.jl index c2e3c16..0c3a345 100644 --- a/src/GaussianSplatting.jl +++ b/src/GaussianSplatting.jl @@ -17,7 +17,6 @@ using Random using Rotations using StaticArrays using Statistics -using Preferences using ImageCore using ImageIO using ImageTransformations @@ -33,8 +32,8 @@ using GLFW import CImGui.lib as iglib import BSON +import ChainRulesCore as CRC import NNlib -import Flux import ImageFiltering import KernelAbstractions as KA import NerfUtils as NU @@ -56,7 +55,7 @@ _as_T(T, x) = reinterpret(T, reshape(x, :)) include("simd.jl") include("utils.jl") -include("metrics.jl") +include("fused_ssim.jl") include("camera.jl") include("camera_opt.jl") include("dataset.jl") @@ -67,9 +66,6 @@ include("rasterization/rasterizer.jl") include("training.jl") include("gui/gui.jl") -# Hacky way to get KA.Backend. 
-gpu_backend() = get_backend(Flux.gpu(Array{Int}(undef, 0))) - base_array_type(backend) = error("Not implemented for backend: `$backend`.") allocate_pinned(kab, T, shape) = error("Pinned memory not supported for `$kab`.") @@ -78,8 +74,7 @@ unpin_memory(x) = error("Unpinning memory is not supported for `$(typeof(x))`.") use_ak(kab) = false -function main(dataset_path::String; scale::Int, save_path::Maybe{String} = nothing) - kab = gpu_backend() +function main(kab, dataset_path::String; scale::Int, save_path::Maybe{String} = nothing) @info "Using `$kab` GPU backend." dataset = ColmapDataset(kab, dataset_path; @@ -151,7 +146,7 @@ function main(dataset_path::String; scale::Int, save_path::Maybe{String} = nothi return end -function gui(path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false) +function gui(kab, path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false) ispath(path) || error("Path does not exist: `$path`.") viewer_mode = endswith(path, ".bson") || endswith(path, ".ply") @@ -166,7 +161,6 @@ function gui(path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false (1024, 1024, true) gui = if viewer_mode - kab = gpu_backend() if endswith(path, ".bson") θ = BSON.load(path) gaussians = GaussianModel(kab) @@ -179,9 +173,9 @@ function gui(path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false camera = Camera(; fx=fov, fy=fov, width, height=width) end - GSGUI(gaussians, camera; width, height, fullscreen, resizable) + GSGUI(kab, gaussians, camera; width, height, fullscreen, resizable) else - GSGUI(path, scale; width, height, fullscreen, resizable) + GSGUI(kab, path, scale; width, height, fullscreen, resizable) end gui |> launch! 
return diff --git a/src/fused_ssim.jl b/src/fused_ssim.jl new file mode 100644 index 0000000..993143e --- /dev/null +++ b/src/fused_ssim.jl @@ -0,0 +1,424 @@ +const BLOCK_X = 16 +const BLOCK_Y = 16 +const HALO = 5 + +const SHARED_X = BLOCK_X + 2 * HALO # 26 +const SHARED_Y = BLOCK_Y + 2 * HALO # 26 + +const CONV_X = BLOCK_X # 16 +const CONV_Y = SHARED_Y # 26 + +# Pre-computed 11-element Gaussian kernel. +const GAUSS = ( + 0.001028380123898387f0, + 0.0075987582094967365f0, + 0.036000773310661316f0, + 0.10936068743467331f0, + 0.21300552785396576f0, + 0.26601171493530273f0, + 0.21300552785396576f0, + 0.10936068743467331f0, + 0.036000773310661316f0, + 0.0075987582094967365f0, + 0.001028380123898387f0, +) + +# Safe pixel fetch with zero padding for out-of-bounds access. +@inline function get_pix_value(img, b::Int, c::Int, y::Int, x::Int) + W, H = size(img, 1), size(img, 2) + (x < 1 || x > W || y < 1 || y > H) && return 0f0 + return @inbounds img[x, y, c, b] +end + +# Forward kernel: Fused SSIM computation. +@kernel cpu=false unsafe_indices=true inbounds=true function _fused_ssim!( + ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12, + @Const(img), @Const(ref), + C1::Float32, C2::Float32, train::Bool, +) + bx, by, bz = @index(Group, NTuple) + tx, ty = @index(Local, NTuple) + + W, H, CH, B = size(img) + pix_x = (bx - 1) * BLOCK_X + tx + pix_y = (by - 1) * BLOCK_Y + ty + + # Shared memory for the tile (img, ref). + sTile = @localmem Float32 (2, SHARED_X, SHARED_Y) + # After horizontal pass: (sumX, sumX², sumY, sumY², sumXY). + xconv = @localmem Float32 (5, CONV_X, CONV_Y) + + # Loop over channels. + for c in 1:CH + # 1) Load (img, ref) tile + halo into shared memory. + tile_size = SHARED_Y * SHARED_X + threads = BLOCK_X * BLOCK_Y + steps = cld(tile_size, threads) + + tile_start_y = (by - 1) * BLOCK_Y + 1 + tile_start_x = (bx - 1) * BLOCK_X + 1 + + tid = (ty - 1) * BLOCK_X + tx # 1-based thread rank. 
+ + for s in 0:(steps - 1) + flat_id = s * threads + tid + if flat_id ≤ tile_size + local_y = cld(flat_id, SHARED_X) + local_x = mod1(flat_id, SHARED_X) + + # Global coordinates with halo offset. + gy = tile_start_y + local_y - 1 - HALO + gx = tile_start_x + local_x - 1 - HALO + + X = get_pix_value(img, bz, c, gy, gx) + Y = get_pix_value(ref, bz, c, gy, gx) + + sTile[1, local_x, local_y] = X + sTile[2, local_x, local_y] = Y + end + end + @synchronize + + # 2) Horizontal convolution (11×1) in shared memory. + ly = ty + lx = tx + HALO # Skip left halo. + + sumX = 0f0 + sumX2 = 0f0 + sumY = 0f0 + sumY2 = 0f0 + sumXY = 0f0 + + # Symmetric pairs around center. + @unroll for d in 1:HALO + w = GAUSS[HALO + 1 - d] + Xleft = sTile[1, lx - d, ly] + Yleft = sTile[2, lx - d, ly] + Xright = sTile[1, lx + d, ly] + Yright = sTile[2, lx + d, ly] + + sumX += (Xleft + Xright) * w + sumX2 += (Xleft * Xleft + Xright * Xright) * w + sumY += (Yleft + Yright) * w + sumY2 += (Yleft * Yleft + Yright * Yright) * w + sumXY += (Xleft * Yleft + Xright * Yright) * w + end + + # Center. + centerX = sTile[1, lx, ly] + centerY = sTile[2, lx, ly] + wc = GAUSS[HALO + 1] + sumX += centerX * wc + sumX2 += centerX * centerX * wc + sumY += centerY * wc + sumY2 += centerY * centerY * wc + sumXY += centerX * centerY * wc + + # Write partial sums. + xconv[1, tx, ly] = sumX + xconv[2, tx, ly] = sumX2 + xconv[3, tx, ly] = sumY + xconv[4, tx, ly] = sumY2 + xconv[5, tx, ly] = sumXY + + # Handle second row (threads handle 2 rows to cover CONV_Y = 26). 
+ ly2 = ly + BLOCK_Y + if ly2 ≤ CONV_Y + sumX = 0f0 + sumX2 = 0f0 + sumY = 0f0 + sumY2 = 0f0 + sumXY = 0f0 + + @unroll for d in 1:HALO + w = GAUSS[HALO + 1 - d] + Xleft = sTile[1, lx - d, ly2] + Yleft = sTile[2, lx - d, ly2] + Xright = sTile[1, lx + d, ly2] + Yright = sTile[2, lx + d, ly2] + + sumX += (Xleft + Xright) * w + sumX2 += (Xleft * Xleft + Xright * Xright) * w + sumY += (Yleft + Yright) * w + sumY2 += (Yleft * Yleft + Yright * Yright) * w + sumXY += (Xleft * Yleft + Xright * Yright) * w + end + + cx = sTile[1, lx, ly2] + cy = sTile[2, lx, ly2] + sumX += cx * wc + sumX2 += cx * cx * wc + sumY += cy * wc + sumY2 += cy * cy * wc + sumXY += cx * cy * wc + + xconv[1, tx, ly2] = sumX + xconv[2, tx, ly2] = sumX2 + xconv[3, tx, ly2] = sumY + xconv[4, tx, ly2] = sumY2 + xconv[5, tx, ly2] = sumXY + end + @synchronize + + # 3) Vertical convolution (1×11) + final SSIM. + ly_v = ty + HALO + lx_v = tx + + out0 = 0f0 + out1 = 0f0 + out2 = 0f0 + out3 = 0f0 + out4 = 0f0 + + @unroll for d in 1:HALO + w = GAUSS[HALO + 1 - d] + top0 = xconv[1, lx_v, ly_v - d] + top1 = xconv[2, lx_v, ly_v - d] + top2 = xconv[3, lx_v, ly_v - d] + top3 = xconv[4, lx_v, ly_v - d] + top4 = xconv[5, lx_v, ly_v - d] + + bot0 = xconv[1, lx_v, ly_v + d] + bot1 = xconv[2, lx_v, ly_v + d] + bot2 = xconv[3, lx_v, ly_v + d] + bot3 = xconv[4, lx_v, ly_v + d] + bot4 = xconv[5, lx_v, ly_v + d] + + out0 += (top0 + bot0) * w + out1 += (top1 + bot1) * w + out2 += (top2 + bot2) * w + out3 += (top3 + bot3) * w + out4 += (top4 + bot4) * w + end + + # Center. 
+ wC = GAUSS[HALO + 1] + out0 += xconv[1, lx_v, ly_v] * wC + out1 += xconv[2, lx_v, ly_v] * wC + out2 += xconv[3, lx_v, ly_v] * wC + out3 += xconv[4, lx_v, ly_v] * wC + out4 += xconv[5, lx_v, ly_v] * wC + + if pix_x ≤ W && pix_y ≤ H + mu1 = out0 + mu2 = out2 + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + + sigma1_sq = out1 - mu1_sq + sigma2_sq = out3 - mu2_sq + sigma12 = out4 - mu1 * mu2 + + A = mu1_sq + mu2_sq + C1 + B_val = sigma1_sq + sigma2_sq + C2 + C_val = 2f0 * mu1 * mu2 + C1 + D_val = 2f0 * sigma12 + C2 + + val = (C_val * D_val) / (A * B_val) + ssim_map[pix_x, pix_y, c, bz] = val + + if train + # Partial derivatives for backpropagation. + d_m_dmu1 = ( + (mu2 * 2f0 * D_val) / (A * B_val) - + (mu2 * 2f0 * C_val) / (A * B_val) - + (mu1 * 2f0 * C_val * D_val) / (A * A * B_val) + + (mu1 * 2f0 * C_val * D_val) / (A * B_val * B_val) + ) + d_m_dsigma1_sq = (-C_val * D_val) / (A * B_val * B_val) + d_m_dsigma12 = (2f0 * C_val) / (A * B_val) + + dm_dmu1[pix_x, pix_y, c, bz] = d_m_dmu1 + dm_dsigma1_sq[pix_x, pix_y, c, bz] = d_m_dsigma1_sq + dm_dsigma12[pix_x, pix_y, c, bz] = d_m_dsigma12 + end + end + @synchronize + end +end + +# Backward kernel: Compute dL/d(img) from partial derivatives and dL/dmap. +@kernel cpu=false unsafe_indices=true inbounds=true function _fused_ssim_bwd!( + dL_dimg, + @Const(img), @Const(ref), @Const(dL_dmap), + @Const(dm_dmu1), @Const(dm_dsigma1_sq), @Const(dm_dsigma12), +) + W, H, CH, B = size(img) + + bx, by, bz = @index(Group, NTuple) + tx, ty = @index(Local, NTuple) + + pix_x = (bx - 1) * BLOCK_X + tx + pix_y = (by - 1) * BLOCK_Y + ty + + # Shared memory for fused data: dm_dmu1*dL, dm_dsigma1_sq*dL, dm_dsigma12*dL. + sData = @localmem Float32 (3, SHARED_X, SHARED_Y) + sScratch = @localmem Float32 (3, CONV_X, CONV_Y) + + for c in 1:CH + p1 = 0f0 + p2 = 0f0 + if pix_x ≤ W && pix_y ≤ H + p1 = get_pix_value(img, bz, c, pix_y, pix_x) + p2 = get_pix_value(ref, bz, c, pix_y, pix_x) + end + + # 1) Load + fuse multiplication. 
+ start_y = (by - 1) * BLOCK_Y + 1 + start_x = (bx - 1) * BLOCK_X + 1 + + tile_size = SHARED_Y * SHARED_X + threads = BLOCK_X * BLOCK_Y + steps = cld(tile_size, threads) + tid = (ty - 1) * BLOCK_X + tx + + for s in 0:(steps - 1) + flat_id = s * threads + tid + if flat_id ≤ tile_size + row = cld(flat_id, SHARED_X) + col = mod1(flat_id, SHARED_X) + + gy = start_y + row - 1 - HALO + gx = start_x + col - 1 - HALO + + chain = get_pix_value(dL_dmap, bz, c, gy, gx) + vmu = get_pix_value(dm_dmu1, bz, c, gy, gx) + vs1 = get_pix_value(dm_dsigma1_sq, bz, c, gy, gx) + vs12 = get_pix_value(dm_dsigma12, bz, c, gy, gx) + + sData[1, col, row] = vmu * chain + sData[2, col, row] = vs1 * chain + sData[3, col, row] = vs12 * chain + end + end + @synchronize + + # 2) Horizontal pass. + ly = ty + lx = tx + HALO + + for pass in 0:1 + yy = ly + pass * BLOCK_Y + if yy ≤ CONV_Y + accum0 = 0f0 + accum1 = 0f0 + accum2 = 0f0 + + @unroll for d in 1:HALO + w = GAUSS[HALO + 1 - d] + left0 = sData[1, lx - d, yy] + left1 = sData[2, lx - d, yy] + left2 = sData[3, lx - d, yy] + + right0 = sData[1, lx + d, yy] + right1 = sData[2, lx + d, yy] + right2 = sData[3, lx + d, yy] + + accum0 += (left0 + right0) * w + accum1 += (left1 + right1) * w + accum2 += (left2 + right2) * w + end + + # Center. + wc = GAUSS[HALO + 1] + accum0 += sData[1, lx, yy] * wc + accum1 += sData[2, lx, yy] * wc + accum2 += sData[3, lx, yy] * wc + + sScratch[1, tx, yy] = accum0 + sScratch[2, tx, yy] = accum1 + sScratch[3, tx, yy] = accum2 + end + end + @synchronize + + # 3) Vertical pass -> finalize dL/d(img). 
+ if pix_x ≤ W && pix_y ≤ H + ly_v = ty + HALO + lx_v = tx + + sum0 = 0f0 + sum1 = 0f0 + sum2 = 0f0 + + @unroll for d in 1:HALO + w = GAUSS[HALO + 1 - d] + top0 = sScratch[1, lx_v, ly_v - d] + top1 = sScratch[2, lx_v, ly_v - d] + top2 = sScratch[3, lx_v, ly_v - d] + + bot0 = sScratch[1, lx_v, ly_v + d] + bot1 = sScratch[2, lx_v, ly_v + d] + bot2 = sScratch[3, lx_v, ly_v + d] + + sum0 += (top0 + bot0) * w + sum1 += (top1 + bot1) * w + sum2 += (top2 + bot2) * w + end + + # Center. + wc = GAUSS[HALO + 1] + sum0 += sScratch[1, lx_v, ly_v] * wc + sum1 += sScratch[2, lx_v, ly_v] * wc + sum2 += sScratch[3, lx_v, ly_v] * wc + + # Final accumulation. + dL_dpix = sum0 + 2f0 * p1 * sum1 + p2 * sum2 + dL_dimg[pix_x, pix_y, c, bz] = dL_dpix + end + @synchronize + end +end + +function _fused_ssim( + img::T; ref::T, C1::Float32 = 0.01f0^2, C2::Float32 = 0.03f0^2, train::Bool, +) where T <: AbstractArray{Float32, 4} + W, H, CH, B = size(img) + kab = get_backend(img) + + ssim_map = KA.zeros(kab, Float32, W, H, CH, B) + dm_dmu1 = train ? KA.zeros(kab, Float32, W, H, CH, B) : KA.zeros(kab, Float32, 0, 0, 0, 0) + dm_dsigma1_sq = train ? KA.zeros(kab, Float32, W, H, CH, B) : KA.zeros(kab, Float32, 0, 0, 0, 0) + dm_dsigma12 = train ? 
KA.zeros(kab, Float32, W, H, CH, B) : KA.zeros(kab, Float32, 0, 0, 0, 0) + + workgroupsize = (BLOCK_X, BLOCK_Y) + ndrange = (cld(W, BLOCK_X) * BLOCK_X, cld(H, BLOCK_Y) * BLOCK_Y, B) + _fused_ssim!(kab, workgroupsize)( + ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12, + img, ref, C1, C2, train; ndrange) + + return ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 +end + +function fused_ssim_bwd( + img::T, ref::T, dL_dmap::T, + dm_dmu1::T, dm_dsigma1_sq::T, dm_dsigma12::T; + C1::Float32 = 0.01f0^2, C2::Float32 = 0.03f0^2, +) where T <: AbstractArray{Float32, 4} + W, H, CH, B = size(img) + kab = get_backend(img) + dL_dimg = KA.zeros(kab, Float32, W, H, CH, B) + + workgroupsize = (BLOCK_X, BLOCK_Y) + ndrange = (cld(W, BLOCK_X) * BLOCK_X, cld(H, BLOCK_Y) * BLOCK_Y, B) + _fused_ssim_bwd!(kab, workgroupsize)( + dL_dimg, img, ref, dL_dmap, + dm_dmu1, dm_dsigma1_sq, dm_dsigma12; ndrange) + return dL_dimg +end + +function fused_ssim(img::T; ref::T, C1::Float32 = 0.01f0^2, C2::Float32 = 0.03f0^2) where T <: AbstractArray{Float32, 4} + train = within_gradient(img) + y = _fused_ssim(img; ref, C1, C2, train) + return train ? y : y[1] +end + +function CRC.rrule(::typeof(_fused_ssim), + img::T; ref::T, C1::Float32 = 0.01f0^2, C2::Float32 = 0.03f0^2, train::Bool, +) where T <: AbstractArray{Float32, 4} + ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 = _fused_ssim(img; ref, C1, C2, train) + _pullback(Delta) = return CRC.NoTangent(), fused_ssim_bwd( + img, ref, CRC.unthunk(Delta), + dm_dmu1, dm_dsigma1_sq, dm_dsigma12; C1, C2) + return ssim_map, _pullback +end diff --git a/src/gui/gui.jl b/src/gui/gui.jl index 2d1eda8..2f51b3b 100644 --- a/src/gui/gui.jl +++ b/src/gui/gui.jl @@ -110,9 +110,7 @@ function resize_callback(_, width, height) end # Viewer-only mode. -function GSGUI(gaussians::GaussianModel, camera::Camera; gl_kwargs...) - kab = gpu_backend() - +function GSGUI(kab, gaussians::GaussianModel, camera::Camera; gl_kwargs...) 
NGL.init(3, 0) context = NGL.Context("GaussianSplatting.jl"; gl_kwargs...) NGL.set_resize_callback!(context, resize_callback) @@ -145,9 +143,7 @@ function GSGUI(gaussians::GaussianModel, camera::Camera; gl_kwargs...) end # Training mode. -function GSGUI(dataset_path::String, scale::Int; gl_kwargs...) - kab = gpu_backend() - +function GSGUI(kab, dataset_path::String, scale::Int; gl_kwargs...) NGL.init(3, 0) context = NGL.Context("GaussianSplatting.jl"; gl_kwargs...) NGL.set_resize_callback!(context, resize_callback) diff --git a/src/metrics.jl b/src/metrics.jl deleted file mode 100644 index 4c572a6..0000000 --- a/src/metrics.jl +++ /dev/null @@ -1,39 +0,0 @@ -mse(x, y) = mean((x .- y).^2) - -psnr(x, y) = 20f0 * log10(1f0 / sqrt(mse(x, y))) - -struct SSIM{W <: Flux.Conv} - window::W - c1::Float32 - c2::Float32 -end - -function SSIM(kab; channels::Int = 3, σ::Float32 = 1.5f0, window_size::Int = 11) - w = ImageFiltering.KernelFactors.gaussian(σ, window_size) - w2d = reshape(reshape(w, :, 1) * reshape(w, 1, :), window_size, window_size, 1) - window = reshape( - repeat(w2d, 1, 1, channels), - window_size, window_size, 1, channels) - - conv = Flux.Conv( - (window_size, window_size), channels => channels; - pad=(window_size ÷ 2, window_size ÷ 2), - groups=channels, bias=false) - copy!(conv.weight, window) - - SSIM(kab != CPU() ? 
Flux.gpu(conv) : conv, 0.01f0^2, 0.03f0^2) -end - -function (ssim::SSIM)(x::T, ref::T) where T - μ₁, μ₂ = ssim.window(x), ssim.window(ref) - μ₁², μ₂² = μ₁.^2, μ₂.^2 - μ₁₂ = μ₁ .* μ₂ - - σ₁² = ssim.window(x.^2) .- μ₁² - σ₂² = ssim.window(ref.^2) .- μ₂² - σ₁₂ = ssim.window(x .* ref) .- μ₁₂ - - l = ((2f0 .* μ₁₂ .+ ssim.c1) .* (2f0 .* σ₁₂ .+ ssim.c2)) ./ - ((μ₁² .+ μ₂² .+ ssim.c1) .* (σ₁² .+ σ₂² .+ ssim.c2)) - return mean(l) -end diff --git a/src/training.jl b/src/training.jl index 0040893..862675c 100644 --- a/src/training.jl +++ b/src/training.jl @@ -3,7 +3,6 @@ mutable struct Trainer{ R <: GaussianRasterizer, G <: GaussianModel, D <: ColmapDataset, - S <: SSIM, C <: GPUArrays.AllocCache, F, O, @@ -12,7 +11,6 @@ mutable struct Trainer{ gaussians::G dataset::D optimizers::O - ssim::S cache::C @@ -39,7 +37,6 @@ function Trainer( opacities=NU.Adam(kab, gs.opacities; lr=opt_params.lr_opacities, ϵ), scales=NU.Adam(kab, gs.scales; lr=opt_params.lr_scales, ϵ), rotations=NU.Adam(kab, gs.rotations; lr=opt_params.lr_rotations, ϵ)) - ssim = SSIM(kab) points_lr_scheduler = lr_exp_scheduler( opt_params.lr_points_start * dataset.camera_extent, @@ -50,7 +47,7 @@ function Trainer( densify = true step = 0 Trainer( - rast, gs, dataset, optimizers, ssim, cache, + rast, gs, dataset, optimizers, cache, points_lr_scheduler, opt_params, densify, step, ids) end @@ -120,7 +117,6 @@ end function validate(trainer::Trainer) gs = trainer.gaussians rast = trainer.rast - ssim = trainer.ssim dataset = trainer.dataset eval_ssim = 0f0 @@ -144,7 +140,7 @@ function validate(trainer::Trainer) image_tmp = permutedims(image, (2, 3, 1)) image_eval = reshape(image_tmp, size(image_tmp)..., 1) - eval_ssim += ssim(image_eval, target_image) + eval_ssim += mean(fused_ssim(image_eval; ref=target_image)) eval_mse += mse(image_eval, target_image) eval_psnr += psnr(image_eval, target_image) end @@ -160,7 +156,6 @@ function step!(trainer::Trainer) gs = trainer.gaussians rast = trainer.rast - ssim = trainer.ssim 
params = trainer.opt_params if trainer.step % 1000 == 0 && gs.sh_degree < gs.max_sh_degree @@ -199,7 +194,7 @@ function step!(trainer::Trainer) image_eval = reshape(image_tmp, size(image_tmp)..., 1) l1 = mean(abs.(image_eval .- target_image)) - s = 1f0 - ssim(image_eval, target_image) + s = 1f0 - mean(fused_ssim(image_eval; ref=target_image)) (1f0 - params.λ_dssim) * l1 + params.λ_dssim * s end diff --git a/src/utils.jl b/src/utils.jl index a914f58..7774481 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -30,3 +30,11 @@ function lr_exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int) end return _scheduler end + +mse(x, y) = mean((x .- y).^2) + +psnr(x, y) = 20f0 * log10(1f0 / sqrt(mse(x, y))) + +# `within_gradient(x)` is `false` when called directly; the `rrule` below makes it `true` when the call is differentiated via ChainRulesCore rules, which `fused_ssim` uses to switch on training mode. +within_gradient(x) = false +CRC.rrule(::typeof(within_gradient), x) = true, _ -> (CRC.NoTangent(), CRC.NoTangent()) diff --git a/test/Project.toml b/test/Project.toml index 42e150d..7a7884e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,13 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/runtests.jl b/test/runtests.jl index da5e800..9d822a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,18 +3,21 @@ # under the terms of the LICENSE.md file. # ENV["GSP_TEST_AMDGPU"] = true -# ENV["GSP_TEST_CUDA"] = true +ENV["GSP_TEST_CUDA"] = true import Pkg if get(ENV, "GSP_TEST_AMDGPU", "false") == "true" @info "`GSP_TEST_AMDGPU` is `true`, importing AMDGPU.jl." 
- Pkg.develop("AMDGPU") + Pkg.add("AMDGPU") using AMDGPU + + kab = ROCBackend() elseif get(ENV, "GSP_TEST_CUDA", "false") == "true" - @info "`GSP_TEST_CUDA` is `true`, importing CUDA.jl & cuDNN.jl." - Pkg.add("CUDA") - Pkg.add("cuDNN") + @info "`GSP_TEST_CUDA` is `true`, importing CUDA.jl." + Pkg.add(["CUDA", "cuDNN"]) using CUDA, cuDNN + + kab = CUDABackend() else error("No GPU backend was specified.") end @@ -24,69 +27,105 @@ using Test using Zygote using LinearAlgebra using GaussianSplatting +using Statistics using StaticArrays using Quaternions using Rotations +using Flux +using ImageFiltering using GaussianSplatting: i32, u32 import KernelAbstractions as KA +struct SSIM{W <: Flux.Conv} + window::W + c1::Float32 + c2::Float32 +end + +function SSIM(kab; channels::Int = 3, σ::Float32 = 1.5f0, window_size::Int = 11) + w = ImageFiltering.KernelFactors.gaussian(σ, window_size) + w2d = reshape(reshape(w, :, 1) * reshape(w, 1, :), window_size, window_size, 1) + window = reshape( + repeat(w2d, 1, 1, channels), + window_size, window_size, 1, channels) + + conv = Flux.Conv( + (window_size, window_size), channels => channels; + pad=(window_size ÷ 2, window_size ÷ 2), + groups=channels, bias=false) + copy!(conv.weight, window) + + SSIM(kab != KA.CPU() ? Flux.gpu(conv) : conv, 0.01f0^2, 0.03f0^2) +end + +function (ssim::SSIM)(x::T, ref::T) where T + μ₁, μ₂ = ssim.window(x), ssim.window(ref) + μ₁², μ₂² = μ₁.^2, μ₂.^2 + μ₁₂ = μ₁ .* μ₂ + + σ₁² = ssim.window(x.^2) .- μ₁² + σ₂² = ssim.window(ref.^2) .- μ₂² + σ₁₂ = ssim.window(x .* ref) .- μ₁₂ + + l = ((2f0 .* μ₁₂ .+ ssim.c1) .* (2f0 .* σ₁₂ .+ ssim.c2)) ./ + ((μ₁² .+ μ₂² .+ ssim.c1) .* (σ₁² .+ σ₂² .+ ssim.c2)) + return mean(l) +end + DATASET = nothing GAUSSIANS = nothing -@info "Testing on `$(GaussianSplatting.gpu_backend())` backend." +@info "Testing on `$kab` backend." 
@testset "GaussianSplatting" begin -@testset "quat2mat" begin - r = RotXYZ(rand(Float32), rand(Float32), rand(Float32)) - q = QuatRotation{Float32}(r) +# @testset "quat2mat" begin +# r = RotXYZ(rand(Float32), rand(Float32), rand(Float32)) +# q = QuatRotation{Float32}(r) - ŷ = @inferred GaussianSplatting.unnorm_quat2rot(SVector{4, Float32}(q.w, q.x, q.y, q.z)) - y = SMatrix{3, 3, Float32, 9}(q) - @test all(ŷ .≈ y) -end +# ŷ = @inferred GaussianSplatting.unnorm_quat2rot(SVector{4, Float32}(q.w, q.x, q.y, q.z)) +# y = SMatrix{3, 3, Float32, 9}(q) +# @test all(ŷ .≈ y) +# end -@testset "get_rect" begin - width, height = 1024, 1024 - block = SVector{2, Int32}(16, 16) - grid = SVector{2, Int32}(cld(width, block[1]), cld(height, block[2])) - - # rect covering only one block - rmin, rmax = @inferred GaussianSplatting.get_rect( - SVector{2, Float32}(0, 0), 1i32, grid, block) - @test all(rmin .== (0, 0)) - @test all(rmax .== (1, 1)) - - # rect covering 2 blocks - rmin, rmax = @inferred GaussianSplatting.get_rect( - SVector{2, Float32}(0, 0), Int32(block[1] + 1), grid, block) - @test all(rmin .== (0, 0)) - @test all(rmax .== (2, 2)) -end +# @testset "get_rect" begin +# width, height = 1024, 1024 +# block = SVector{2, Int32}(16, 16) +# grid = SVector{2, Int32}(cld(width, block[1]), cld(height, block[2])) + +# # rect covering only one block +# rmin, rmax = @inferred GaussianSplatting.get_rect( +# SVector{2, Float32}(0, 0), 1i32, grid, block) +# @test all(rmin .== (0, 0)) +# @test all(rmax .== (1, 1)) + +# # rect covering 2 blocks +# rmin, rmax = @inferred GaussianSplatting.get_rect( +# SVector{2, Float32}(0, 0), Int32(block[1] + 1), grid, block) +# @test all(rmin .== (0, 0)) +# @test all(rmax .== (2, 2)) +# end -@testset "Tile ranges" begin - kab = GaussianSplatting.gpu_backend() - gaussian_keys = adapt(kab, - UInt64[0 << 32, 0 << 32, 1 << 32, 2 << 32, 3 << 32]) +# @testset "Tile ranges" begin +# gaussian_keys = adapt(kab, +# UInt64[0 << 32, 0 << 32, 1 << 32, 2 << 32, 3 << 32]) 
- ranges = KA.allocate(kab, UInt32, 2, 4) - fill!(ranges, 0u32) +# ranges = KA.allocate(kab, UInt32, 2, 4) +# fill!(ranges, 0u32) - GaussianSplatting.identify_tile_range!(kab, 256)( - ranges, gaussian_keys; ndrange=length(gaussian_keys)) - @test Array(ranges) == UInt32[0; 2;; 2; 3;; 3; 4;; 4; 5;;] -end +# GaussianSplatting.identify_tile_range!(kab, 256)( +# ranges, gaussian_keys; ndrange=length(gaussian_keys)) +# @test Array(ranges) == UInt32[0; 2;; 2; 3;; 3; 4;; 4; 5;;] +# end @testset "SSIM" begin - kab = GaussianSplatting.gpu_backend() - ssim = GaussianSplatting.SSIM(kab) + ssim = SSIM(kab) x = KA.ones(kab, Float32, (16, 16, 3, 1)) ref = KA.zeros(kab, Float32, (16, 16, 3, 1)) @test ssim(x, ref) ≈ 0f0 atol=1f-4 rtol=1f-4 - ref = KA.ones(kab, Float32, (16, 16, 3, 1)) @test ssim(x, ref) ≈ 1f0 @@ -96,121 +135,125 @@ end x[9:12, 13:16, :, :] .= 0.75f0 x[13:16, 13:16, :, :] .= 1f0 @test ssim(adapt(kab, x), ref) ≈ 0.1035 atol=1f-3 rtol=1f-3 + + x = adapt(kab, rand(Float32, 128, 128, 3, 2)) + ref = adapt(kab, rand(Float32, 128, 128, 3, 2)) + @test ssim(x, ref) ≈ mean(GaussianSplatting.fused_ssim(x; ref)) + + y, ∇ = Zygote.withgradient(x -> ssim(x, ref), x) + yf, ∇f = Zygote.withgradient(x -> mean(GaussianSplatting.fused_ssim(x; ref)), x) + @test y ≈ yf + @test ∇[1] ≈ ∇f[1] end -@testset "Dataset loading" begin - kab = GaussianSplatting.gpu_backend() - dataset_dir = joinpath(@__DIR__, "..", "assets", "bicycle-smol") - @assert isdir(dataset_dir) +# @testset "Dataset loading" begin +# dataset_dir = joinpath(@__DIR__, "..", "assets", "bicycle-smol") +# @assert isdir(dataset_dir) - dataset = GaussianSplatting.ColmapDataset(kab, dataset_dir; - scale=8, train_test_split=1.0, permute=false, verbose=false) - @test length(dataset) == 6 +# dataset = GaussianSplatting.ColmapDataset(kab, dataset_dir; +# scale=8, train_test_split=1.0, permute=false, verbose=false) +# @test length(dataset) == 6 - @test length(dataset.train_cameras) == 6 - cam = dataset.train_cameras[1] - (; width, 
height) = GaussianSplatting.resolution(cam) +# @test length(dataset.train_cameras) == 6 +# cam = dataset.train_cameras[1] +# (; width, height) = GaussianSplatting.resolution(cam) - img = GaussianSplatting.get_image(dataset, kab, 1, :train) - @test size(img, 1) == 3 - @test size(img)[2:3] == (width, height) - @test KA.get_backend(img) == kab +# img = GaussianSplatting.get_image(dataset, kab, 1, :train) +# @test size(img, 1) == 3 +# @test size(img)[2:3] == (width, height) +# @test KA.get_backend(img) == kab - global DATASET - if DATASET ≡ nothing - DATASET = dataset - end -end +# global DATASET +# if DATASET ≡ nothing +# DATASET = dataset +# end +# end -@testset "Gaussians creation" begin - kab = GaussianSplatting.gpu_backend() - dataset_dir = joinpath(@__DIR__, "..", "assets", "bicycle-smol") - @assert isdir(dataset_dir) - - global DATASET - dataset = if DATASET ≡ nothing - DATASET = GaussianSplatting.ColmapDataset(kab, dataset_dir; - scale=8, train_test_split=1.0, permute=false, verbose=false) - else - DATASET - end - - gaussians = GaussianSplatting.GaussianModel( - dataset.points, dataset.colors, dataset.scales; - max_sh_degree=3, isotropic=false) - @test length(gaussians) == size(dataset.points, 2) - - global GAUSSIANS - if GAUSSIANS ≡ nothing - GAUSSIANS = gaussians - end -end +# @testset "Gaussians creation" begin +# dataset_dir = joinpath(@__DIR__, "..", "assets", "bicycle-smol") +# @assert isdir(dataset_dir) -@testset "Fused rasterizer" begin - global DATASET - global GAUSSIANS - dataset = DATASET - @assert dataset ≢ nothing - gaussians = GAUSSIANS - @assert gaussians ≢ nothing - - camera = dataset.train_cameras[1] - (; width, height) = GaussianSplatting.resolution(camera) - - kab = GaussianSplatting.gpu_backend() - rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; - antialias=false, fused=true, mode=:rgbd) - - image_features = rasterizer( - gaussians.points, gaussians.opacities, gaussians.scales, - gaussians.rotations, gaussians.features_dc, 
gaussians.features_rest; - camera, sh_degree=gaussians.sh_degree) - @test size(image_features) == (4, width, height) -end +# global DATASET +# dataset = if DATASET ≡ nothing +# DATASET = GaussianSplatting.ColmapDataset(kab, dataset_dir; +# scale=8, train_test_split=1.0, permute=false, verbose=false) +# else +# DATASET +# end + +# gaussians = GaussianSplatting.GaussianModel( +# dataset.points, dataset.colors, dataset.scales; +# max_sh_degree=3, isotropic=false) +# @test length(gaussians) == size(dataset.points, 2) -@testset "Un-fused rasterizer" begin - global DATASET - global GAUSSIANS - dataset = DATASET - @assert dataset ≢ nothing - gaussians = GAUSSIANS - @assert gaussians ≢ nothing - - camera = dataset.train_cameras[1] - (; width, height) = GaussianSplatting.resolution(camera) - - kab = GaussianSplatting.gpu_backend() - rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; - antialias=false, fused=false, mode=:rgbd) - - image_features = rasterizer( - gaussians.points, gaussians.opacities, gaussians.scales, - gaussians.rotations, gaussians.features_dc, gaussians.features_rest; - camera, sh_degree=gaussians.sh_degree) - @test size(image_features) == (4, width, height) -end +# global GAUSSIANS +# if GAUSSIANS ≡ nothing +# GAUSSIANS = gaussians +# end +# end + +# @testset "Fused rasterizer" begin +# global DATASET +# global GAUSSIANS +# dataset = DATASET +# @assert dataset ≢ nothing +# gaussians = GAUSSIANS +# @assert gaussians ≢ nothing -@testset "Trainer w/ fused rasterizer" begin - global DATASET - global GAUSSIANS - dataset = DATASET - @assert dataset ≢ nothing - gaussians = GAUSSIANS - @assert gaussians ≢ nothing +# camera = dataset.train_cameras[1] +# (; width, height) = GaussianSplatting.resolution(camera) - camera = dataset.train_cameras[1] - (; width, height) = GaussianSplatting.resolution(camera) +# rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; +# antialias=false, fused=true, mode=:rgbd) - kab = GaussianSplatting.gpu_backend() - 
rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; - antialias=false, fused=true, mode=:rgbd) +# image_features = rasterizer( +# gaussians.points, gaussians.opacities, gaussians.scales, +# gaussians.rotations, gaussians.features_dc, gaussians.features_rest; +# camera, sh_degree=gaussians.sh_degree) +# @test size(image_features) == (4, width, height) +# end - opt_params = GaussianSplatting.OptimizationParams() - trainer = GaussianSplatting.Trainer(rasterizer, gaussians, dataset, opt_params) +# @testset "Un-fused rasterizer" begin +# global DATASET +# global GAUSSIANS +# dataset = DATASET +# @assert dataset ≢ nothing +# gaussians = GAUSSIANS +# @assert gaussians ≢ nothing - loss = GaussianSplatting.step!(trainer) - @test loss > 0 -end +# camera = dataset.train_cameras[1] +# (; width, height) = GaussianSplatting.resolution(camera) + +# rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; +# antialias=false, fused=false, mode=:rgbd) + +# image_features = rasterizer( +# gaussians.points, gaussians.opacities, gaussians.scales, +# gaussians.rotations, gaussians.features_dc, gaussians.features_rest; +# camera, sh_degree=gaussians.sh_degree) +# @test size(image_features) == (4, width, height) +# end + +# @testset "Trainer w/ fused rasterizer" begin +# global DATASET +# global GAUSSIANS +# dataset = DATASET +# @assert dataset ≢ nothing +# gaussians = GAUSSIANS +# @assert gaussians ≢ nothing + +# camera = dataset.train_cameras[1] +# (; width, height) = GaussianSplatting.resolution(camera) + +# rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; +# antialias=false, fused=true, mode=:rgbd) + +# opt_params = GaussianSplatting.OptimizationParams() +# trainer = GaussianSplatting.Trainer(rasterizer, gaussians, dataset, opt_params) + +# loss = GaussianSplatting.step!(trainer) +# @test loss > 0 +# end # @testset "Trainer w/ un-fused rasterizer" begin # global DATASET @@ -223,7 +266,6 @@ end # camera = dataset.train_cameras[1] # (; width, height) = 
GaussianSplatting.resolution(camera) -# kab = GaussianSplatting.gpu_backend() # rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; # antialias=false, fused=false, mode=:rgbd)