Skip to content

Commit 91d813a

Browse files
committed
misc
1 parent 2ffa035 commit 91d813a

File tree

4 files changed

+11
-54
lines changed

4 files changed

+11
-54
lines changed

examples/simple_trainer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,12 @@ def __init__(self):
5858
torch.profiler.ProfilerActivity.CUDA,
5959
]
6060

61-
# 基础配置
62-
self.wait = 1 # 开始记录前等待的步数
63-
self.warmup = 2 # 预热步数
64-
self.active = 30_000 # 实际分析的步数
65-
# self.repeat = 2 # 重复次数
66-
# self.skip_first = 10 # 跳过前N步(可选)
61+
self.wait = 1
62+
self.warmup = 2
63+
self.active = 30_000
6764

68-
# 创建schedule
6965
self.schedule = self._create_schedule()
7066

71-
# 其他profiler设置
7267
self.on_trace_ready = torch.profiler.tensorboard_trace_handler('./log/profiler')
7368
self.record_shapes = True
7469
self.profile_memory = True
@@ -79,8 +74,6 @@ def _create_schedule(self):
7974
wait=self.wait,
8075
warmup=self.warmup,
8176
active=self.active,
82-
# repeat=self.repeat,
83-
# skip_first=self.skip_first
8477
)
8578

8679
def update_schedule(self, **kwargs):

gsplat/compression/outlier_filter.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,28 @@
66
from torch import Tensor
77

88
def filter_splats(splats: Dict[str, Tensor], opa_thres: float=0.005, std_factor: float=2.0, k_neighbors: int=10):
9-
# 1. 首先基于不透明度过滤
109
opacity_mask = torch.sigmoid(splats["opacities"]) >= opa_thres
1110

12-
# # 2. 使用KD树计算每个点到其k个最近邻的平均距离
11+
1312
# pos = splats["means"].cpu().numpy()
1413
# kdtree = KDTree(pos)
1514
# distances, _ = kdtree.query(pos, k=k_neighbors)
1615
# mean_distances = np.mean(distances, axis=1)
1716

18-
# # 3. 计算平均距离的统计特征
17+
#
1918
# dist_mean = np.mean(mean_distances)
2019
# dist_std = np.std(mean_distances)
2120

22-
# # 4. 基于距离判定离群点
2321
# distance_mask = mean_distances <= (dist_mean + std_factor * dist_std)
2422
# distance_mask = torch.from_numpy(distance_mask).to(opacity_mask.device)
2523

26-
# # 5. 组合两个过滤条件
27-
# ## v1: 保守方案
24+
# ## v1:
2825
# # outlier = torch.logical_and(~opacity_mask, ~distance_mask) # 既在位置上离群,且不透明度又低
2926
# # valid_mask = ~outlier
30-
# ## v2: 激进方案
27+
# ## v2:
3128
# # valid_mask = torch.logical_and(opacity_mask, distance_mask)
3229
valid_mask = opacity_mask
3330

34-
# 6. 逐一过滤元素
3531
for n, v in splats.items():
3632
splats[n] = v[valid_mask]
3733

gsplat/compression_simulation/ada_mask.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,21 @@ def get_temperature(self, current_step):
2020
if current_step < self.annealing_start_iter:
2121
return self.start_temp
2222

23-
# 计算退火后的温度
2423
progress = (current_step - self.annealing_start_iter) / \
2524
(self.total_iters - self.annealing_start_iter)
26-
progress = min(max(progress, 0), 1) # 限制在[0,1]范围内
25+
progress = min(max(progress, 0), 1)
2726

28-
# 使用指数衰减
2927
temperature = self.start_temp * math.exp(
3028
math.log(self.end_temp / self.start_temp) * progress
3129
)
3230
return temperature
3331

3432
def forward(self, x, current_step):
3533
if self.training:
36-
# 训练时使用当前温度的sigmoid
3734
self.current_iter = current_step
3835
temperature = self.get_temperature(current_step)
3936
mask = torch.sigmoid(self.mask_logits / temperature)
4037
else:
41-
# 测试时使用硬阈值
4238
mask = (torch.sigmoid(self.mask_logits) >= 0.5).float()
4339
return x * mask
4440

@@ -47,21 +43,13 @@ def get_binary_mask(self):
4743
return (torch.sigmoid(self.mask_logits) >= 0.5).float()
4844

4945
def get_sparsity_loss(self, lambda_l1=0.01, lambda_target=0.1):
50-
"""
51-
计算稀疏性损失,包括:
52-
1. L1正则化
53-
2. 目标稀疏度的KL散度损失
54-
"""
5546
temperature = self.get_temperature(self.current_iter)
5647
mask = torch.sigmoid(self.mask_logits / temperature)
5748

58-
# L1 正则化
5949
l1_loss = lambda_l1 * torch.mean(mask)
6050

61-
# 计算当前的平均激活率
6251
current_sparsity = torch.mean(mask)
6352

64-
# KL散度损失,确保整体稀疏度接近目标值
6553
target = torch.tensor(self.target_sparsity).to(mask.device)
6654
kl_loss = lambda_target * F.binary_cross_entropy(current_sparsity, target)
6755

gsplat/compression_simulation/entropy_model.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ def forward(self, x, Q=None, **kwargs):
161161
else:
162162
Q = torch.tensor([Q], device=x.device)
163163

164-
# [N, C] -> [C, 1, N], batch维度移到最后以提高内存访问效率
164+
# [N, C] -> [C, 1, N]
165165
x = x.t().unsqueeze(1)
166166

167-
# 预计算公共部分
168167
half_Q = 0.5 * Q.detach()
169168
x_lower = x - half_Q
170169
x_upper = x + half_Q
@@ -209,10 +208,9 @@ def forward(self, x, Q=None, **kwargs):
209208
else:
210209
Q = torch.tensor([Q], device=x.device)
211210

212-
# [N, C] -> [C, 1, N], batch维度移到最后以提高内存访问效率
211+
# [N, C] -> [C, 1, N]
213212
x = x.t().unsqueeze(1)
214213

215-
# 预计算公共部分
216214
half_Q = 0.5 * Q.detach()
217215
x_lower = x - half_Q
218216
x_upper = x + half_Q
@@ -265,10 +263,9 @@ def get_likelihood(self, x, Q=None, **kwargs):
265263
else:
266264
Q = torch.tensor([Q], device=x.device)
267265

268-
# [N, C] -> [C, 1, N], batch维度移到最后以提高内存访问效率
266+
# [N, C] -> [C, 1, N],
269267
x = x.t().unsqueeze(1)
270268

271-
# 预计算公共部分
272269
half_Q = 0.5 * Q.detach()
273270
x_lower = x - half_Q
274271
x_upper = x + half_Q
@@ -346,23 +343,6 @@ def get_means_and_scales(self, pos):
346343
scales = torch.clamp(scales, min=1e-9)
347344

348345
return means, scales
349-
350-
# class Low_bound(torch.autograd.Function):
351-
# @staticmethod
352-
# def forward(ctx, x):
353-
# ctx.save_for_backward(x)
354-
# x = torch.clamp(x, min=1e-6)
355-
# return x
356-
357-
# @staticmethod
358-
# def backward(ctx, g):
359-
# x, = ctx.saved_tensors
360-
# grad1 = g.clone()
361-
# grad1[x < 1e-6] = 0
362-
# pass_through_if = np.logical_or(
363-
# x.cpu().numpy() >= 1e-6, g.cpu().numpy() < 0.0)
364-
# t = torch.Tensor(pass_through_if+0.0).cuda()
365-
# return grad1 * t
366346

367347
def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
368348
return torch.max(x, bound)

0 commit comments

Comments
 (0)