Skip to content

Commit d51c8d3

Browse files
committed
Add PLY export & fix format
1 parent a8d95a1 commit d51c8d3

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

src/GaussianSplatting.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ end
154154
function gui(path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false)
155155
ispath(path) || error("Path does not exist: `$path`.")
156156

157-
viewer_mode = endswith(path, ".bson")
157+
viewer_mode = endswith(path, ".bson") || endswith(path, ".ply")
158158
!viewer_mode && !isdir(path) && error(
159-
"`path` must be either a `.bson` model checkpoint or " *
159+
"`path` must be either a [`.bson`, '.ply'] model checkpoint or " *
160160
"a directory with COLMAP dataset, instead: `$path`.")
161161
!viewer_mode && scale nothing && error(
162162
"`scale` keyword argument must be specified if `path` is a COLMAP dataset.")
@@ -166,10 +166,18 @@ function gui(path::String; scale::Maybe{Int} = nothing, fullscreen::Bool = false
166166
(1024, 1024, true)
167167

168168
gui = if viewer_mode
169-
θ = BSON.load(path)
170-
gaussians = GaussianModel(gpu_backend())
171-
set_from_bson!(gaussians, θ[:gaussians])
172-
camera = θ[:camera]
169+
kab = gpu_backend()
170+
if endswith(path, ".bson")
171+
θ = BSON.load(path)
172+
gaussians = GaussianModel(kab)
173+
set_from_bson!(gaussians, θ[:gaussians])
174+
camera = θ[:camera]
175+
else
176+
gaussians = import_ply(path, kab)
177+
width = 1024
178+
fov = NU.fov2focal(1024, 45f0)
179+
camera = Camera(; fx=fov, fy=fov, width, height=width)
180+
end
173181

174182
GSGUI(gaussians, camera; width, height, fullscreen, resizable)
175183
else

src/gaussians.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ sh_2_rgb(x) = x * SH0 + 0.5f0
157157

158158
inverse_sigmoid(x) = log(x / (1f0 - x))
159159

160-
function export_ply!(g::GaussianModel, filename::String)
160+
function export_ply(g::GaussianModel, filename::String)
161161
ply = PlyIO.Ply()
162162

163163
n = size(g.points, 2)
@@ -174,16 +174,16 @@ function export_ply!(g::GaussianModel, filename::String)
174174
PlyIO.ArrayProperty("y", xyz[2, :]),
175175
PlyIO.ArrayProperty("z", xyz[3, :]),
176176

177-
[PlyIO.ArrayProperty("f_dc_$i", features_dc[i, :]) for i in 1:size(features_dc, 1)]...,
178-
[PlyIO.ArrayProperty("f_rest_$i", features_rest[i, :]) for i in 1:size(features_rest, 1)]...,
179-
[PlyIO.ArrayProperty("scale_$i", scales[i, :]) for i in 1:size(scales, 1)]...,
180-
[PlyIO.ArrayProperty("rot_$i", rotations[i, :]) for i in 1:size(rotations, 1)]...,
177+
[PlyIO.ArrayProperty("f_dc_$(i - 1)", features_dc[i, :]) for i in 1:size(features_dc, 1)]...,
178+
[PlyIO.ArrayProperty("f_rest_$(i - 1)", features_rest[i, :]) for i in 1:size(features_rest, 1)]...,
179+
[PlyIO.ArrayProperty("scale_$(i - 1)", scales[i, :]) for i in 1:size(scales, 1)]...,
180+
[PlyIO.ArrayProperty("rot_$(i - 1)", rotations[i, :]) for i in 1:size(rotations, 1)]...,
181181

182182
PlyIO.ArrayProperty("opacity", opacities),
183183
)
184184
push!(ply, vertex)
185185

186-
PlyIO.save_ply(ply, filename)
186+
PlyIO.save_ply(ply, filename; ascii=false)
187187
return
188188
end
189189

@@ -196,14 +196,14 @@ function import_ply(filename::String, kab)
196196

197197
n = length(vertex["x"])
198198
xyz = vcat([reshape(vertex[i], 1, n) for i in ("x", "y", "z")]...)
199-
scales = vcat([reshape(vertex["scale_$i"], 1, n) for i in 1:3]...)
200-
rotations = vcat([reshape(vertex["rot_$i"], 1, n) for i in 1:4]...)
199+
scales = vcat([reshape(vertex["scale_$(i - 1)"], 1, n) for i in 1:3]...)
200+
rotations = vcat([reshape(vertex["rot_$(i - 1)"], 1, n) for i in 1:4]...)
201201
opacities = reshape(Array(vertex["opacity"]), 1, n)
202202

203-
features_dc = vcat([reshape(vertex["f_dc_$i"], 1, 1, n) for i in 1:3]...)
203+
features_dc = vcat([reshape(vertex["f_dc_$(i - 1)"], 1, 1, n) for i in 1:3]...)
204204
features_rest = if n_frest > 0
205205
reshape(
206-
vcat([reshape(vertex["f_rest_$i"], 1, n) for i in 1:n_frest]...),
206+
vcat([reshape(vertex["f_rest_$(i - 1)"], 1, n) for i in 1:n_frest]...),
207207
3, :, n)
208208
else
209209
Array{Float32}(undef, 3, 0, n)

src/gui/gui.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,17 @@ function handle_ui!(gui::GSGUI; frame_time)
354354
end
355355
CImGui.Separator()
356356

357+
if CImGui.Button("Export PLY", CImGui.ImVec2(-1, 0))
358+
save_dir = unsafe_string(pointer(gui.ui_state.save_directory_path))
359+
isdir(save_dir) || mkpath(save_dir)
360+
361+
tstmp = now()
362+
fmt = "timestamp-$(month(tstmp))M-$(day(tstmp))D-$(hour(tstmp)):$(minute(tstmp))"
363+
save_file = joinpath(save_dir, "state-(step-$(gui.trainer.step))-($fmt).ply")
364+
export_ply(gui.gaussians, save_file)
365+
end
366+
CImGui.Separator()
367+
357368
CImGui.Text("Path to State File (.bson):")
358369
CImGui.PushItemWidth(-1)
359370
CImGui.InputText(

0 commit comments

Comments
 (0)