import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

_interp_to_id = {
    'linear': 0,
    'smoothstep': 1,
}

class STE_binary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # out = torch.sign(input)
        p = (input >= 0) * (+1.0)
        n = (input < 0) * (-1.0)
        out = p + n
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # mask: pass gradients only where the input lies within [-1, 1]
        input, = ctx.saved_tensors
        i2 = input.clone().detach()
        i3 = torch.clamp(i2, -1, 1)
        mask = (i3 == i2) + 0.0
        return grad_output * mask

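# Illustrative usage sketch (not part of the module): STE_binary binarizes a
# tensor to {-1, +1} in the forward pass; in the backward pass gradients flow
# through unchanged where the input lies in [-1, 1] and are zeroed elsewhere.
#
#   x = torch.tensor([-2.0, -0.3, 0.4, 1.5], requires_grad=True)
#   y = STE_binary.apply(x)   # tensor([-1., -1., 1., 1.])
#   y.sum().backward()
#   x.grad                    # tensor([0., 1., 1., 0.])
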

class STE_multistep(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, Q):
        # quantize to the nearest multiple of the step size Q
        return torch.round(input / Q) * Q

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: gradient w.r.t. input passes unchanged; Q gets no gradient
        return grad_output, None

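# Illustrative usage sketch (not part of the module): STE_multistep snaps values
# to the nearest multiple of Q while letting gradients pass straight through.
#
#   x = torch.tensor([0.23, 0.51, -0.87], requires_grad=True)
#   q = STE_multistep.apply(x, 0.1)   # tensor([ 0.2000,  0.5000, -0.9000])
#   q.sum().backward()
#   x.grad                            # tensor([1., 1., 1.])
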

class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets_list, resolutions_list, calc_grad_inputs=False, max_level=None):
        # inputs: [N, num_dim], float in [0, 1]
        # embeddings: [sO, n_features], float, where sO = total hash table entries over all levels.
        #             self.embeddings = nn.Parameter(torch.empty(offset, n_features))
        # offsets_list: [n_levels + 1], int
        # RETURN: [N, F], float

        inputs = inputs.contiguous()

        N, num_dim = inputs.shape  # batch size, coord dim (e.g. N_rays, 3)
        n_levels = offsets_list.shape[0] - 1  # number of levels (e.g. 16)
        n_features = embeddings.shape[1]  # embedding dim (channels) per level, e.g. 2

        max_level = n_levels if max_level is None else min(max_level, n_levels)

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if n_features % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and n_features % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # n_levels first, optimize cache for cuda kernel, but needs an extra permute later
        # outputs: [n_levels, N, n_features], buffer filled by the CUDA kernel
        outputs = torch.empty(n_levels, N, n_features, device=inputs.device, dtype=embeddings.dtype)

        # zero init if we only calculate partial levels
        if max_level < n_levels: outputs.zero_()

        if calc_grad_inputs:  # inputs.requires_grad
            dy_dx = torch.empty(N, n_levels * num_dim * n_features, device=inputs.device, dtype=embeddings.dtype)
            if max_level < n_levels: dy_dx.zero_()
        else:
            dy_dx = None

        _backend.grid_encode_forward(
            inputs,
            embeddings,
            offsets_list,
            resolutions_list,
            outputs,
            N, num_dim, n_features, n_levels, max_level,
            dy_dx
        )

        # permute back to [N, n_levels * n_features]
        outputs = outputs.permute(1, 0, 2).reshape(N, n_levels * n_features)

        ctx.save_for_backward(inputs, embeddings, offsets_list, resolutions_list, dy_dx)
        ctx.dims = [N, num_dim, n_features, n_levels, max_level]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets_list, resolutions_list, dy_dx = ctx.saved_tensors
        N, num_dim, n_features, n_levels, max_level = ctx.dims

        # grad: [N, n_levels * n_features] --> [n_levels, N, n_features]
        grad = grad.view(N, n_levels, n_features).permute(1, 0, 2).contiguous()

        # gradient buffer with the same shape as the embeddings; the kernel accumulates
        # into it directly, so it must be zero-initialized
        grad_embeddings = torch.zeros_like(embeddings)

        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(
            grad,
            inputs,
            embeddings,
            offsets_list,
            resolutions_list,
            grad_embeddings,
            N, num_dim, n_features, n_levels, max_level,
            dy_dx,
            grad_inputs
        )

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        return grad_inputs, grad_embeddings, None, None, None, None


grid_encode = _grid_encode.apply
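# Hypothetical call for illustration (shapes only; assumes a CUDA device and the
# offsets/resolutions buffers built by GridEncoder below):
#
#   feats = grid_encode(xyz, embeddings, offsets_list, resolutions_list,
#                       xyz.requires_grad, None)   # -> [N, n_levels * n_features]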


class GridEncoder(nn.Module):
    def __init__(self,
                 num_dim=3,
                 n_features=2,
                 resolutions_list=(16, 23, 32, 46, 64, 92, 128, 184, 256, 368, 512, 736),
                 log2_hashmap_size=19,
                 ste_binary=False,
                 ):
        super().__init__()

        resolutions_list = torch.tensor(resolutions_list).to(torch.int)
        n_levels = resolutions_list.numel()

        self.num_dim = num_dim  # coord dims, 2 or 3
        self.n_levels = n_levels  # num levels (length of resolutions_list)
        self.n_features = n_features  # encode channels per level
        self.log2_hashmap_size = log2_hashmap_size
        self.output_dim = n_levels * n_features
        self.ste_binary = ste_binary

        # allocate parameters
        offsets_list = []  # cumulative sum of per-level hash table sizes
        offset = 0  # running total of hash table entries over all levels
        self.max_params = 2 ** log2_hashmap_size  # upper bound on the hash table size per level (as in the paper)
        for i in range(n_levels):
            resolution = resolutions_list[i].item()
            params_in_level = min(self.max_params, resolution ** num_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets_list.append(offset)
            offset += params_in_level
        offsets_list.append(offset)
        offsets_list = torch.from_numpy(np.array(offsets_list, dtype=np.int32))
        self.register_buffer('offsets_list', offsets_list)
        self.register_buffer('resolutions_list', resolutions_list)

        self.n_params = offsets_list[-1] * n_features  # total number of parameters over all levels

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, n_features))

        self.reset_parameters()

        self.n_output_dims = n_levels * n_features

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: num_dim={self.num_dim} n_levels={self.n_levels} n_features={self.n_features} resolutions={self.resolutions_list.tolist()} log2_hashmap_size={self.log2_hashmap_size} params={tuple(self.embeddings.shape)} ste_binary={self.ste_binary}"

    def forward(self, inputs, max_level=None):
        # inputs: [..., num_dim], float, normalized positions expected in [0, 1] (see _grid_encode.forward)
        # max_level: only calculate first max_level levels (None will use all levels)
        # return: [..., n_levels * n_features]

        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.num_dim)

        if self.ste_binary:
            # binarize the grid parameters with a straight-through estimator
            embeddings = STE_binary.apply(self.embeddings)
        else:
            embeddings = self.embeddings
        outputs = grid_encode(inputs, embeddings, self.offsets_list, self.resolutions_list, inputs.requires_grad, max_level)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
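

# Minimal usage sketch (illustrative only; requires the compiled _gridencoder
# CUDA extension and a CUDA device):
#
#   enc = GridEncoder(num_dim=3, n_features=2, log2_hashmap_size=19).cuda()
#   xyz = torch.rand(4096, 3, device='cuda')   # positions normalized to [0, 1]
#   feats = enc(xyz)                           # [4096, enc.output_dim] == [4096, 24]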