Skip to content

Commit 63db6f6

Browse files
committed
misc: update params distribution visualization function
1 parent bb04f83 commit 63db6f6

File tree

1 file changed

+45
-19
lines changed

1 file changed

+45
-19
lines changed

examples/simple_trainer.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,12 +1187,14 @@ def train(self):
11871187

11881188
# eval the full set
11891189
if step in [i - 1 for i in cfg.eval_steps]:
1190+
self.run_param_distribution_vis(self.comp_sim_splats,
1191+
f"{cfg.result_dir}/visualization/comp_sim_step{step}")
11901192
self.eval(step)
11911193
self.render_traj(step)
11921194

11931195
# run compression
1194-
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
1195-
self.run_compression(step=step)
1196+
# if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
1197+
# self.run_compression(step=step)
11961198

11971199
if not cfg.disable_viewer:
11981200
self.viewer.lock.release()
@@ -1399,32 +1401,56 @@ def run_compression(self, step: int):
13991401
@torch.no_grad()
14001402
def run_param_distribution_vis(self, param_dict: Dict[str, Tensor], save_dir: str):
14011403
import matplotlib.pyplot as plt
1404+
import matplotlib.ticker as ticker
1405+
from matplotlib.colors import LinearSegmentedColormap
14021406

1403-
os.makedirs(save_dir, exist_ok=True)
1404-
for param_name, value in param_dict.items():
1405-
1407+
def plot_distribution(value, param_name, save_dir):
14061408
tensor_np = value.flatten().detach().cpu().numpy()
14071409
min_val, max_val = tensor_np.min(), tensor_np.max()
1408-
1409-
plt.figure(figsize=(6, 4))
1410-
n, bins, patches = plt.hist(tensor_np, bins=50, density=False, alpha=0.7, color='b')
1411-
1412-
for count, bin_edge in zip(n, bins):
1413-
plt.text(bin_edge, count, f'{int(count)}', fontsize=8, va='bottom', ha='center')
1414-
1410+
1411+
nice_blue = '#4878CF' # Brighter blue
1412+
1413+
plt.figure(figsize=(6, 4.5), dpi=100)
1414+
1415+
# Use more bins for a smoother histogram
1416+
n, bins, patches = plt.hist(tensor_np, bins=50, density=False, alpha=0.85,
1417+
color=nice_blue, edgecolor='none')
1418+
1419+
# Add grid lines but place them behind the chart
1420+
plt.grid(alpha=0.3, linestyle='--', axis='y')
1421+
plt.gca().set_axisbelow(True)
1422+
1423+
# Use scientific notation for y-axis ticks
1424+
plt.gca().yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
1425+
plt.gca().ticklabel_format(style='sci', axis='y', scilimits=(0,0))
1426+
1427+
# Improved annotations for minimum and maximum values, smaller size
14151428
plt.annotate(f'Min: {min_val:.2f}', xy=(min_val, 0), xytext=(min_val, max(n) * 0.1),
1416-
arrowprops=dict(facecolor='green', shrink=0.05), fontsize=10, color='green')
1429+
arrowprops=dict(facecolor='green', width=1.5, headwidth=6, headlength=6, shrink=0.05),
1430+
fontsize=8, color='darkgreen', weight='bold',
1431+
bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="green", alpha=0.7))
14171432

14181433
plt.annotate(f'Max: {max_val:.2f}', xy=(max_val, 0), xytext=(max_val, max(n) * 0.1),
1419-
arrowprops=dict(facecolor='red', shrink=0.05), fontsize=10, color='red')
1420-
1434+
arrowprops=dict(facecolor='red', width=1.5, headwidth=6, headlength=6, shrink=0.05),
1435+
fontsize=8, color='darkred', weight='bold',
1436+
bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="red", alpha=0.7))
1437+
1438+
# Beautify title and labels
14211439
plt.title(f'{param_name} Distribution')
14221440
plt.xlabel('Value')
1423-
plt.ylabel('Density')
1424-
1425-
plt.savefig(os.path.join(save_dir, f'{param_name}.png'))
1426-
1441+
plt.ylabel('Frequency')
1442+
1443+
# Adjust x and y axis ranges to leave enough space for annotations
1444+
plt.xlim(min_val - (max_val - min_val) * 0.05, max_val + (max_val - min_val) * 0.05)
1445+
plt.ylim(0, max(n) * 1.2)
1446+
1447+
plt.tight_layout()
1448+
plt.savefig(os.path.join(save_dir, f'{param_name}.png'), dpi=120, bbox_inches='tight')
14271449
plt.close()
1450+
1451+
os.makedirs(save_dir, exist_ok=True)
1452+
for param_name, value in param_dict.items():
1453+
plot_distribution(value, param_name, save_dir)
14281454

14291455
print(f"Histograms saved in '{save_dir}' directory.")
14301456

0 commit comments

Comments
 (0)