forked from graphdeco-inria/gaussian-splatting
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path GaussianScoreTracker.py
More file actions
51 lines (42 loc) · 1.52 KB
/
GaussianScoreTracker.py
File metadata and controls
51 lines (42 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
class GaussianScoreTracker:
    """Accumulate a per-Gaussian running average of scores across updates.

    Each call to ``update`` adds a (clamped) score tensor into a float32
    running sum and increments a counter; ``get_scores`` returns their
    quotient. If an incoming tensor's first dimension differs from the
    currently tracked length, the accumulated state is discarded and
    restarted for the new length.
    """

    def __init__(self, max_score=100):
        """
        Tracks Gaussian scores using a simple running average.

        Args:
            max_score (float or None): Upper bound applied to every incoming
                score before it is accumulated. ``None`` disables clamping.
                Defaults to 100, the bound this class has always used.
        """
        self.running_sum = None  # float32 accumulator; allocated lazily by reset()
        self.count = 0  # number of update() calls since the last reset
        self.gaussian_count = 0  # length of the score tensor currently tracked
        self.max_score = max_score

    def update(self, new_scores):
        """
        Updates the running average of Gaussian scores.

        The tracker is reset first when nothing has been accumulated yet or
        when ``new_scores.shape[0]`` differs from the tracked length.

        Args:
            new_scores (torch.Tensor): Tensor of current frame's Gaussian
                scores; indexed along its first dimension by Gaussian.
        """
        with torch.no_grad():
            # Detach so the accumulator never holds onto the autograd graph.
            new_scores = new_scores.detach()
            if self.max_score is not None:
                new_scores = torch.clamp(new_scores, max=self.max_score)
            num_gaussians = new_scores.shape[0]
            if self.running_sum is None or num_gaussians != self.gaussian_count:
                # Allocate on the input's device rather than assuming CUDA,
                # so the tracker also works for CPU (or other-device) tensors.
                self.reset(num_gaussians, device=new_scores.device)
            # In-place add requires matching devices; .to() is a no-op when
            # the scores already live where the accumulator does.
            self.running_sum += new_scores.to(self.running_sum.device)
            self.count += 1

    def reset(self, num_gaussians, device="cuda"):
        """
        Resets the tracker if the number of Gaussians changes.

        Discards any accumulated scores and allocates a zeroed float32
        accumulator of length ``num_gaussians``.

        Args:
            num_gaussians (int): The new number of Gaussians.
            device (torch.device or str): Device for the new accumulator.
                Defaults to ``"cuda"`` for backward compatibility; ``update``
                passes the device of the incoming scores.
        """
        print("Resetting Gaussian Score Tracker")
        self.running_sum = torch.zeros(num_gaussians, device=device, dtype=torch.float32)
        self.count = 0
        self.gaussian_count = num_gaussians

    def get_scores(self):
        """
        Returns the current average scores.

        Returns:
            torch.Tensor or None: Element-wise mean of all scores accumulated
            since the last reset, or ``None`` if nothing has been accumulated.
        """
        if self.running_sum is None or self.count == 0:
            return None
        return self.running_sum / self.count