Skip to content

Commit ef3bdaf

Browse files
authored
Add Metal support (#31)
1 parent 29ea327 commit ef3bdaf

File tree

9 files changed

+51
-33
lines changed

9 files changed

+51
-33
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "GaussianSplatting"
22
uuid = "991e6b22-92f0-46ac-8b70-8a93e7beee5d"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
authors = ["Anton Smirnov <[email protected]>"]
55

66
[deps]
7+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
910
CImGui = "5d785b6c-b76f-510e-a07c-3070796c7e87"
@@ -38,13 +39,16 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3839
[weakdeps]
3940
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4041
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
42+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4143

4244
[extensions]
4345
GaussianSplattingAMDGPUExt = "AMDGPU"
4446
GaussianSplattingCUDAExt = "CUDA"
47+
GaussianSplattingMetalExt = "Metal"
4548

4649
[compat]
4750
AMDGPU = "2"
51+
AcceleratedKernels = "0.4.3"
4852
Adapt = "4"
4953
BSON = "0.3"
5054
CImGui = "6"
@@ -61,10 +65,11 @@ ImageIO = "0.6"
6165
ImageMagick = "1.3"
6266
ImageTransformations = "0.10"
6367
KernelAbstractions = "0.9.34"
68+
Metal = "1.9"
6469
ModernGL = "1.1"
6570
NearestNeighbors = "0.4"
6671
NerfUtils = "0.2"
67-
NeuralGraphicsGL = "0.5"
72+
NeuralGraphicsGL = "0.5.1"
6873
PlyIO = "1.2"
6974
Quaternions = "0.7"
7075
Rotations = "1.7"

docs/src/index.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ Gaussian Splatting algorithm in pure Julia.
77
## Requirements
88

99
- Julia 1.10 or higher.
10-
- AMD ([AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl)) or
11-
Nvidia ([CUDA.jl](https://github.com/JuliaGPU/CUDA.jl)) capable machine.
10+
- [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) or
11+
[CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) or
12+
[Metal.jl](https://github.com/JuliaGPU/Metal.jl) capable machine.
1213

1314
## Install
1415

@@ -24,14 +25,16 @@ GaussianSplatting.jl comes with a GUI application to train & view the gaussians.
2425

2526
1. Add necessary packages:
2627
```julia
27-
] add AMDGPU # for AMD GPU
28-
] add CUDA # for Nvidia GPU
28+
] add AMDGPU # for AMD GPU
29+
] add CUDA # for Nvidia GPU
30+
] add Metal # for Apple GPU
2931
```
3032

3133
2. Run:
3234
```julia
33-
julia> using AMDGPU; kab = ROCBackend() # for AMD GPU
34-
julia> using CUDA; kab = CUDABackend() # for Nvidia GPU
35+
julia> using AMDGPU; kab = ROCBackend() # for AMD GPU
36+
julia> using CUDA; kab = CUDABackend() # for Nvidia GPU
37+
julia> using Metal; kab = MetalBackend() # for Apple GPU
3538
julia> GaussianSplatting.gui(kab, "path-to-colmap-dataset-directory"; scale=1)
3639
```
3740

ext/GaussianSplattingAMDGPUExt/GaussianSplattingAMDGPUExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ GaussianSplatting.base_array_type(::ROCBackend) = ROCArray
1212

1313
GaussianSplatting.use_ak(::ROCBackend) = true
1414

15-
function GaussianSplatting.allocate_pinned(kab, ::Type{T}, shape) where T
15+
function GaussianSplatting.allocate_pinned(::ROCBackend, ::Type{T}, shape) where T
1616
x = Array{T}(undef, shape)
1717
xd = unsafe_wrap(ROCArray, pointer(x), size(x))
1818
return x, xd

ext/GaussianSplattingMetalExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module GaussianSplattingMetalExt
2+
3+
using Metal
4+
using GaussianSplatting
5+
6+
GaussianSplatting.base_array_type(::MetalBackend) = MtlArray
7+
8+
GaussianSplatting.use_ak(::MetalBackend) = true
9+
10+
function GaussianSplatting.allocate_pinned(::MetalBackend, ::Type{T}, shape) where T
11+
xd = MtlArray{T, length(shape), Metal.SharedStorage}(undef, shape)
12+
x = reshape(unsafe_wrap(Vector{T}, reshape(xd, :)), shape)
13+
return x, xd
14+
end
15+
16+
# Unregistered automatically in the array dtor.
17+
GaussianSplatting.unpin_memory(::MtlArray) = return
18+
19+
end

src/GaussianSplatting.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using GLFW
3131

3232
import CImGui.lib as iglib
3333

34+
import AcceleratedKernels as AK
3435
import BSON
3536
import ChainRulesCore as CRC
3637
import ImageFiltering

src/gui/gui.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ end
111111

112112
# Viewer-only mode.
113113
function GSGUI(kab, gaussians::GaussianModel, camera::Camera; gl_kwargs...)
114-
NGL.init(3, 0)
114+
NGL.init(3, 2)
115115
context = NGL.Context("GaussianSplatting.jl"; gl_kwargs...)
116116
NGL.set_resize_callback!(context, resize_callback)
117117

@@ -144,7 +144,7 @@ end
144144

145145
# Training mode.
146146
function GSGUI(kab, dataset_path::String, scale::Int; gl_kwargs...)
147-
NGL.init(3, 0)
147+
NGL.init(3, 2)
148148
context = NGL.Context("GaussianSplatting.jl"; gl_kwargs...)
149149
NGL.set_resize_callback!(context, resize_callback)
150150

src/rasterization/projection.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function project(
44
rotations::AbstractMatrix{Float32};
55
rast::GaussianRasterizer, camera::Camera,
66
near_plane::Float32, far_plane::Float32,
7-
radius_clip::Float32, blur_ϵ::Float32,
7+
radius_clip::Int32, blur_ϵ::Float32,
88
)
99
(; width, height) = resolution(camera)
1010
@assert width % 16 == 0 && height % 16 == 0
@@ -57,7 +57,7 @@ function ∇project(
5757
conics::AbstractMatrix{Float32};
5858
rast::GaussianRasterizer, camera::Camera,
5959
near_plane::Float32, far_plane::Float32,
60-
radius_clip::Float32, blur_ϵ::Float32,
60+
radius_clip::Int32, blur_ϵ::Float32,
6161
)
6262
K = camera.intrinsics
6363
R_w2c = SMatrix{3, 3, Float32}(camera.w2c[1:3, 1:3])
@@ -110,7 +110,7 @@ function ChainRulesCore.rrule(::typeof(project),
110110

111111
rast::GaussianRasterizer, camera::Camera,
112112
near_plane::Float32, far_plane::Float32,
113-
radius_clip::Float32, blur_ϵ::Float32,
113+
radius_clip::Int32, blur_ϵ::Float32,
114114
)
115115
means_2d, conics, compensations, depths = project(
116116
means_3d, scales, rotations;
@@ -151,7 +151,7 @@ end
151151
# Config.
152152
near_plane::Float32,
153153
far_plane::Float32,
154-
radius_clip::Float32,
154+
radius_clip::Int32,
155155
blur_ϵ::Float32,
156156
) where {C <: Maybe{AbstractMatrix{Float32}}, RM}
157157
i = @index(Global)
@@ -194,10 +194,10 @@ end
194194

195195
# Discard Gaussians outside of image plane.
196196
if (
197-
(mean_2D[1] + radius) 0 ||
198-
(mean_2D[1] - radius) resolution[1] ||
199-
(mean_2D[2] + radius) 0 ||
200-
(mean_2D[2] - radius) resolution[2]
197+
(mean_2D[1] + radius) 0f0 ||
198+
(mean_2D[1] - radius) Float32(resolution[1]) ||
199+
(mean_2D[2] + radius) 0f0 ||
200+
(mean_2D[2] - radius) Float32(resolution[2])
201201
)
202202
radii[i] = 0i32
203203
return

src/rasterization/rasterizer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function (rast::GaussianRasterizer)(
192192
means_2d, conics, compensations, depths = project(
193193
means_3d, scales_act, rotations;
194194
rast, camera, near_plane=0.2f0, far_plane=1000f0,
195-
radius_clip=3f0, blur_ϵ=0.3f0)
195+
radius_clip=Int32(3), blur_ϵ=0.3f0)
196196

197197
colors = spherical_harmonics(means_3d, shs; rast, camera, sh_degree)
198198

@@ -252,7 +252,7 @@ function rasterize(
252252

253253
# TODO make configurable.
254254
near_plane, far_plane = 0.2f0, 1000f0
255-
radius_clip = 3f0 # In pixels.
255+
radius_clip = Int32(3) # In pixels.
256256
blur_ϵ = 0.3f0
257257

258258
project!(kab)(
@@ -315,7 +315,7 @@ function rasterize(
315315
rast.gstate.radii, rast.grid, BLOCK; ndrange=n)
316316

317317
if use_ak(kab)
318-
sortperm!(
318+
AK.sortperm!(
319319
@view(rast.bstate.permutation[1:n_rendered]),
320320
@view(rast.bstate.gaussian_keys_unsorted[1:n_rendered]);
321321
temp=@view(rast.bstate.permutation_tmp[1:n_rendered]))

src/rasterization/utils.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ sdiagm(x, y, z) = SMatrix{3, 3, Float32, 9}(
1313

1414
gpu_floor(T, x) = unsafe_trunc(T, floor(x))
1515
gpu_ceil(T, x) = unsafe_trunc(T, ceil(x))
16-
17-
gpu_cld(x, y::T) where T = (x + y - one(T)) ÷ y
16+
gpu_cld(x::X, y::T) where {X, T} = unsafe_trunc(T, floor(Float32(x + y - one(X)) / Float32(y)))
1817

1918
Base.@propagate_inbounds function get_rect(
2019
pixel::SVector{2, Float32}, max_radius::Int32,
@@ -26,15 +25,6 @@ Base.@propagate_inbounds function get_rect(
2625
@inbounds rmax = SVector{2, Int32}(
2726
clamp(gpu_floor(Int32, gpu_cld(pixel[1] + max_radius, block[1])), 0i32, grid[1]),
2827
clamp(gpu_floor(Int32, gpu_cld(pixel[2] + max_radius, block[2])), 0i32, grid[2]))
29-
30-
# rblock = inv.(block)
31-
32-
# rmin = gpu_floor.(Int32, (pixel .- max_radius) .* rblock)
33-
# rmin = clamp.(rmin, 0i32, grid)
34-
35-
# rmax = gpu_ceil.(Int32, (pixel .+ max_radius) .* rblock)
36-
# # rmax = gpu_cld.(pixel .+ max_radius, block)
37-
# rmax = clamp.(rmax, 0i32, grid)
3828
return rmin, rmax
3929
end
4030

0 commit comments

Comments
 (0)