@@ -1187,12 +1187,14 @@ def train(self):
11871187
11881188 # eval the full set
11891189 if step in [i - 1 for i in cfg .eval_steps ]:
1190+ self .run_param_distribution_vis (self .comp_sim_splats ,
1191+ f"{ cfg .result_dir } /visualization/comp_sim_step{ step } " )
11901192 self .eval (step )
11911193 self .render_traj (step )
11921194
11931195 # run compression
1194- if cfg .compression is not None and step in [i - 1 for i in cfg .eval_steps ]:
1195- self .run_compression (step = step )
1196+ # if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
1197+ # self.run_compression(step=step)
11961198
11971199 if not cfg .disable_viewer :
11981200 self .viewer .lock .release ()
@@ -1399,32 +1401,56 @@ def run_compression(self, step: int):
13991401 @torch .no_grad ()
14001402 def run_param_distribution_vis (self , param_dict : Dict [str , Tensor ], save_dir : str ):
14011403 import matplotlib .pyplot as plt
1404+ import matplotlib .ticker as ticker
1405+ from matplotlib .colors import LinearSegmentedColormap
14021406
1403- os .makedirs (save_dir , exist_ok = True )
1404- for param_name , value in param_dict .items ():
1405-
1407+ def plot_distribution (value , param_name , save_dir ):
14061408 tensor_np = value .flatten ().detach ().cpu ().numpy ()
14071409 min_val , max_val = tensor_np .min (), tensor_np .max ()
1408-
1409- plt .figure (figsize = (6 , 4 ))
1410- n , bins , patches = plt .hist (tensor_np , bins = 50 , density = False , alpha = 0.7 , color = 'b' )
1411-
1412- for count , bin_edge in zip (n , bins ):
1413- plt .text (bin_edge , count , f'{ int (count )} ' , fontsize = 8 , va = 'bottom' , ha = 'center' )
1414-
1410+
1411+ nice_blue = '#4878CF' # Brighter blue
1412+
1413+ plt .figure (figsize = (6 , 4.5 ), dpi = 100 )
1414+
1415+ # Use more bins for a smoother histogram
1416+ n , bins , patches = plt .hist (tensor_np , bins = 50 , density = False , alpha = 0.85 ,
1417+ color = nice_blue , edgecolor = 'none' )
1418+
1419+ # Add grid lines but place them behind the chart
1420+ plt .grid (alpha = 0.3 , linestyle = '--' , axis = 'y' )
1421+ plt .gca ().set_axisbelow (True )
1422+
1423+ # Use scientific notation for y-axis ticks
1424+ plt .gca ().yaxis .set_major_formatter (ticker .ScalarFormatter (useMathText = True ))
1425+ plt .gca ().ticklabel_format (style = 'sci' , axis = 'y' , scilimits = (0 ,0 ))
1426+
1427+ # Improved annotations for minimum and maximum values, smaller size
14151428 plt .annotate (f'Min: { min_val :.2f} ' , xy = (min_val , 0 ), xytext = (min_val , max (n ) * 0.1 ),
1416- arrowprops = dict (facecolor = 'green' , shrink = 0.05 ), fontsize = 10 , color = 'green' )
1429+ arrowprops = dict (facecolor = 'green' , width = 1.5 , headwidth = 6 , headlength = 6 , shrink = 0.05 ),
1430+ fontsize = 8 , color = 'darkgreen' , weight = 'bold' ,
1431+ bbox = dict (boxstyle = "round,pad=0.1" , fc = "white" , ec = "green" , alpha = 0.7 ))
14171432
14181433 plt .annotate (f'Max: { max_val :.2f} ' , xy = (max_val , 0 ), xytext = (max_val , max (n ) * 0.1 ),
1419- arrowprops = dict (facecolor = 'red' , shrink = 0.05 ), fontsize = 10 , color = 'red' )
1420-
1434+ arrowprops = dict (facecolor = 'red' , width = 1.5 , headwidth = 6 , headlength = 6 , shrink = 0.05 ),
1435+ fontsize = 8 , color = 'darkred' , weight = 'bold' ,
1436+ bbox = dict (boxstyle = "round,pad=0.1" , fc = "white" , ec = "red" , alpha = 0.7 ))
1437+
1438+ # Beautify title and labels
14211439 plt .title (f'{ param_name } Distribution' )
14221440 plt .xlabel ('Value' )
1423- plt .ylabel ('Density' )
1424-
1425- plt .savefig (os .path .join (save_dir , f'{ param_name } .png' ))
1426-
1441+ plt .ylabel ('Frequency' )
1442+
1443+ # Adjust x and y axis ranges to leave enough space for annotations
1444+ plt .xlim (min_val - (max_val - min_val ) * 0.05 , max_val + (max_val - min_val ) * 0.05 )
1445+ plt .ylim (0 , max (n ) * 1.2 )
1446+
1447+ plt .tight_layout ()
1448+ plt .savefig (os .path .join (save_dir , f'{ param_name } .png' ), dpi = 120 , bbox_inches = 'tight' )
14271449 plt .close ()
1450+
1451+ os .makedirs (save_dir , exist_ok = True )
1452+ for param_name , value in param_dict .items ():
1453+ plot_distribution (value , param_name , save_dir )
14281454
14291455 print (f"Histograms saved in '{ save_dir } ' directory." )
14301456
0 commit comments