Skip to content

Commit 93d7d35

Browse files
committed
fix missing imports & add more warnings
1 parent c42f15e commit 93d7d35

File tree

4 files changed

+211
-3
lines changed

4 files changed

+211
-3
lines changed

fast_gauss/base_utils.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from __future__ import annotations
2+
from copy import copy
3+
from typing import Mapping, TypeVar, Union, Iterable, Callable, Dict, List
4+
# these are generic type vars to tell mapping to accept any type vars when creating a type
5+
KT = TypeVar("KT") # key type
6+
VT = TypeVar("VT") # value type
7+
8+
# TODO: move this to engine implementation
9+
# TODO: this is a special type just like Config
10+
# ? However, dotdict is a general purpose data passing object, instead of just designed for config
11+
# The only reason we defined those special variables are for type annotations
12+
# If removed, all will still work flawlessly, just no editor annotation for output, type and meta
13+
14+
15+
def return_dotdict(func: Callable):
16+
def inner(*args, **kwargs):
17+
return dotdict(func(*args, **kwargs))
18+
return inner
19+
20+
21+
class DoNothing:
22+
def __getattr__(self, name):
23+
def method(*args, **kwargs):
24+
pass
25+
return method
26+
27+
28+
class dotdict(dict, Dict[KT, VT]):
29+
"""
30+
This is the default data passing object used throughout the codebase
31+
Main function: dot access for dict values & dict like merging and updates
32+
33+
a dictionary that supports dot notation
34+
as well as dictionary access notation
35+
usage: d = make_dotdict() or d = make_dotdict{'val1':'first'})
36+
set attributes: d.val2 = 'second' or d['val2'] = 'second'
37+
get attributes: d.val2 or d['val2']
38+
"""
39+
40+
def update(self, dct: Dict = None, **kwargs):
41+
dct = copy(dct) # avoid modifying the original dict, use super's copy to avoid recursion
42+
43+
# Handle different arguments
44+
if dct is None:
45+
dct = kwargs
46+
elif isinstance(dct, Mapping):
47+
dct.update(kwargs)
48+
else:
49+
super().update(dct, **kwargs)
50+
return
51+
52+
# Recursive updates
53+
for k, v in dct.items():
54+
if k in self:
55+
56+
# Handle type conversions
57+
target_type = type(self[k])
58+
if not isinstance(v, target_type):
59+
# NOTE: bool('False') will be True
60+
if target_type == bool and isinstance(v, str):
61+
dct[k] = v == 'True'
62+
else:
63+
dct[k] = target_type(v)
64+
65+
if isinstance(v, dict):
66+
self[k].update(v) # recursion from here
67+
else:
68+
self[k] = v
69+
else:
70+
if isinstance(v, dict):
71+
self[k] = dotdict(v) # recursion?
72+
elif isinstance(v, list):
73+
self[k] = [dotdict(x) if isinstance(x, dict) else x for x in v]
74+
else:
75+
self[k] = v
76+
return self
77+
78+
def __init__(self, *args, **kwargs):
79+
self.update(*args, **kwargs)
80+
81+
copy = return_dotdict(dict.copy)
82+
fromkeys = return_dotdict(dict.fromkeys)
83+
84+
# def __hash__(self):
85+
# # return hash(''.join([str(self.values().__hash__())]))
86+
# return super(dotdict, self).__hash__()
87+
88+
# def __init__(self, *args, **kwargs):
89+
# super(dotdict, self).__init__(*args, **kwargs)
90+
91+
"""
92+
Uncomment following lines and
93+
comment out __getattr__ = dict.__getitem__ to get feature:
94+
95+
returns empty numpy array for undefined keys, so that you can easily copy things around
96+
TODO: potential caveat, harder to trace where this is set to np.array([], dtype=np.float32)
97+
"""
98+
99+
def __getitem__(self, key):
100+
try:
101+
return dict.__getitem__(self, key)
102+
except KeyError as e:
103+
raise AttributeError(e)
104+
# MARK: Might encounter exception in newer version of pytorch
105+
# Traceback (most recent call last):
106+
# File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 245, in _feed
107+
# obj = _ForkingPickler.dumps(obj)
108+
# File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
109+
# cls(buf, protocol).dump(obj)
110+
# KeyError: '__getstate__'
111+
# MARK: Because you allow your __getattr__() implementation to raise the wrong kind of exception.
112+
# FIXME: not working typing hinting code
113+
__getattr__: Callable[..., 'torch.Tensor'] = __getitem__ # type: ignore # overidden dict.__getitem__
114+
__getattribute__: Callable[..., 'torch.Tensor'] # type: ignore
115+
# __getattr__ = dict.__getitem__
116+
__setattr__ = dict.__setitem__
117+
__delattr__ = dict.__delitem__
118+
119+
# TODO: better ways to programmically define these special variables?
120+
121+
@property
122+
def meta(self) -> dotdict:
123+
# Special variable used for storing cpu tensor in batch
124+
if 'meta' not in self:
125+
self.meta = dotdict()
126+
return self.__getitem__('meta')
127+
128+
@meta.setter
129+
def meta(self, meta):
130+
self.__setitem__('meta', meta)
131+
132+
@property
133+
def output(self) -> dotdict: # late annotation needed for this
134+
# Special entry for storing output tensor in batch
135+
if 'output' not in self:
136+
self.output = dotdict()
137+
return self.__getitem__('output')
138+
139+
@output.setter
140+
def output(self, output):
141+
self.__setitem__('output', output)
142+
143+
@property
144+
def persistent(self) -> dotdict: # late annotation needed for this
145+
# Special entry for storing persistent tensor in batch
146+
if 'persistent' not in self:
147+
self.persistent = dotdict()
148+
return self.__getitem__('persistent')
149+
150+
@persistent.setter
151+
def persistent(self, persistent):
152+
self.__setitem__('persistent', persistent)
153+
154+
@property
155+
def type(self) -> str: # late annotation needed for this
156+
# Special entry for type based construction system
157+
return self.__getitem__('type')
158+
159+
@type.setter
160+
def type(self, type):
161+
self.__setitem__('type', type)
162+
163+
def to_dict(self):
164+
out = dict()
165+
for k, v in self.items():
166+
if isinstance(v, dotdict):
167+
v = v.to_dict() # recursion point
168+
out[k] = v
169+
return out
170+
171+
172+
class default_dotdict(dotdict):
173+
def __init__(self, default_type=object, *arg, **kwargs):
174+
super().__init__(*arg, **kwargs)
175+
dict.__setattr__(self, 'default_type', default_type)
176+
177+
def __getitem__(self, key):
178+
try:
179+
return super().__getitem__(key)
180+
except (AttributeError, KeyError) as e:
181+
super().__setitem__(key, dict.__getattribute__(self, 'default_type')())
182+
return super().__getitem__(key)

fast_gauss/console_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class Colors:
8080
from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, filesize, ProgressColumn
8181
from tqdm.std import tqdm as std_tqdm
8282
from tqdm.rich import tqdm_rich, FractionColumn, RateColumn
83-
from easyvolcap.utils.base_utils import default_dotdict, dotdict, DoNothing
83+
from .base_utils import default_dotdict, dotdict, DoNothing
8484

8585
pdbr_theme = 'ansi_dark'
8686
pdbr.utils.set_traceback(pdbr_theme)
@@ -787,3 +787,11 @@ def build_parser(d: dict, parser: argparse.ArgumentParser = None, **kwargs):
787787
parser.add_argument(f'--{k}', type=type(v), default=v, help=markup_to_ansi(help_pattern.format(v)))
788788

789789
return parser
790+
791+
792+
def warn_once(message: str):
793+
if not hasattr(warn_once, 'warned'):
794+
warn_once.warned = set()
795+
if message not in warn_once.warned:
796+
log(yellow(message))
797+
warn_once.warned.add(message)

fast_gauss/gsplat_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def __init__(self,
5454
self.resize_buffers(init_buffer_size)
5555
self.resize_textures(*init_texture_size)
5656

57+
log(green_slim(f'GSplatContextManager initialized with attribute dtype: {self.dtype}, texture dtype: {self.tex_dtype}, offline rendering: {self.offline_rendering}, buffer size: {init_buffer_size}, texture size: {init_texture_size}'))
58+
59+
if not self.offline_rendering:
60+
log(green_slim('Using online rendering mode, in this mode, calling the rendering function of fast_gauss will write directly to the currently bound framebuffer'))
61+
log(green_slim('In this mode, the output of all rasterization calls will be None (same output count). Please do not perform further processing on them.'))
62+
log(green_slim('Please make sure to set up the correct GUI environment before calling the rasterization function, see more in readme.md'))
63+
5764
def opengl_options(self):
5865
# Performs face culling
5966
gl.glDisable(gl.GL_CULL_FACE)
@@ -220,6 +227,13 @@ def resize_buffers(self, v: int = 0):
220227

221228
@torch.no_grad()
222229
def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ1: torch.Tensor, raster_settings: 'GaussianRasterizationSettings'):
230+
if xyz3.dtype != self.dtype:
231+
warn_once(yellow(f'Input tensors has dtype {xyz3.dtype}, expected {self.dtype}, will cast to {self.dtype}'))
232+
xyz3, cov6, rgb3, occ1 = xyz3.to(self.dtype), cov6.to(self.dtype), rgb3.to(self.dtype), occ1.to(self.dtype)
233+
for key in raster_settings:
234+
if isinstance(raster_settings[key], torch.Tensor):
235+
raster_settings[key] = raster_settings[key].to(self.dtype)
236+
223237
# Prepare OpenGL texture size
224238
H, W = raster_settings.image_height, raster_settings.image_width
225239
self.resize_textures(H, W)
@@ -237,7 +251,7 @@ def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ
237251

238252
# Upload sorted data to OpenGL for rendering
239253
from cuda import cudart
240-
from easyvolcap.utils.cuda_utils import CHECK_CUDART_ERROR, FORMAT_CUDART_ERROR
254+
from .cuda_utils import CHECK_CUDART_ERROR, FORMAT_CUDART_ERROR
241255
kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
242256

243257
CHECK_CUDART_ERROR(cudart.cudaGraphicsMapResources(1, self.cu_vbo, torch.cuda.current_stream().cuda_stream))
@@ -293,6 +307,10 @@ def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ
293307
torch.cuda.current_stream().cuda_stream)) # stream
294308
CHECK_CUDART_ERROR(cudart.cudaGraphicsUnmapResources(1, cu_tex, torch.cuda.current_stream().cuda_stream))
295309

310+
if rgba_map.dtype != xyz3.dtype:
311+
warn_once(yellow(f'Using texture dtype {rgba_map.dtype}, expected {xyz3.dtype} for the output, will cast to {xyz3.dtype}'))
312+
rgba_map = rgba_map.to(xyz3.dtype)
313+
296314
return rgba_map # H, W, 4
297315
else:
298316
return None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "fast_gauss"
7-
version = "0.0.5"
7+
version = "0.0.6"
88
description = "A geometry-shader-based, global CUDA sorted high-performance 3D Gaussian Splatting rasterizer. Can achieve a 5-10x speedup in rendering compared to the vanialla diff-gaussian-rasterization."
99
readme = "readme.md"
1010
license = { file = "license" }

0 commit comments

Comments
 (0)