@@ -81,6 +81,61 @@ def test_log_dict(self: TensorBoardLoggerTest) -> None:
8181 )
8282 self .assertEqual (tensor_tag .step , 1 )
8383
84+ def test_log_histogram_raw (self : TensorBoardLoggerTest ) -> None :
85+ with tempfile .TemporaryDirectory () as log_dir :
86+ logger = TensorBoardLogger (path = log_dir )
87+
88+ # generate a histogram with 4 bins in the range [0, 1]
89+ data_range = [0.0 , 1.0 ]
90+ bucket_counts = [1 , 3 , 5 , 4 ]
91+ bucket_width = (data_range [1 ] - data_range [0 ]) / len (bucket_counts )
92+ bucket_limits = [
93+ ix * bucket_width + data_range [0 ]
94+ for ix in range (len (bucket_counts ) + 1 )
95+ ]
96+ bucket_centers = [
97+ (lower + upper ) / 2
98+ for lower , upper in zip (bucket_limits [:- 1 ], bucket_limits [1 :])
99+ ]
100+ # sum of the binned values
101+ value_sum = float (
102+ sum (
103+ value * count for value , count in zip (bucket_centers , bucket_counts )
104+ )
105+ )
106+
107+ logger .log_histogram_raw (
108+ "histogram_raw" ,
109+ min = 0 ,
110+ max = 1 ,
111+ num = sum (bucket_counts ),
112+ sum = value_sum ,
113+ sum_squares = value_sum ** 2 ,
114+ bucket_limits = bucket_limits ,
115+ # add an extra leading 0 to match the format of the histogram_raw
116+ bucket_counts = [0 ] + bucket_counts ,
117+ )
118+ logger .close ()
119+
120+ acc = EventAccumulator (log_dir )
121+ acc .Reload ()
122+
123+ # check that the histogram is logged correctly
124+ self .assertIn ("histogram_raw" , acc .Tags ()["histograms" ])
125+ # ensure that we logged exactly one histogram
126+ self .assertEqual (len (acc .Histograms ("histogram_raw" )), 1 )
127+ histogram_event = acc .Histograms ("histogram_raw" )[0 ]
128+ histogram_value = histogram_event .histogram_value
129+ # check that the histogram is logged correctly
130+ self .assertEqual (histogram_value .min , 0 )
131+ self .assertEqual (histogram_value .max , 1 )
132+ self .assertEqual (histogram_value .num , sum (bucket_counts ))
133+ self .assertEqual (histogram_value .sum , value_sum )
134+ self .assertEqual (histogram_value .sum_squares , value_sum ** 2 )
135+ self .assertListEqual (histogram_value .bucket_limit , bucket_limits )
136+ self .assertListEqual (histogram_value .bucket [1 :], bucket_counts )
137+ self .assertEqual (histogram_value .bucket [0 ], 0 )
138+
84139 def test_log_text (self : TensorBoardLoggerTest ) -> None :
85140 with tempfile .TemporaryDirectory () as log_dir :
86141 logger = TensorBoardLogger (path = log_dir )
0 commit comments