Skip to content

Commit 8a69041

Browse files
committed
Added HistogramByStep panel
1 parent 4ae4a53 commit 8a69041

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from comet_ml import API, ui
2+
from comet_ml.data_structure import Histogram
3+
import random
4+
import numpy as np
5+
import plotly.express as px
6+
7+
# Options:
8+
9+
max_xbins = st.sidebar.slider(
10+
"Maximum X bins", # Label for the slider
11+
min_value=5, # Minimum allowed value
12+
max_value=20, # Maximum allowed value
13+
value=10, # Default starting value
14+
step=1 # Increment step for the slider
15+
)
16+
max_ybins = st.sidebar.slider(
17+
"Maximum Y bins", # Label for the slider
18+
min_value=5, # Minimum allowed value
19+
max_value=100, # Maximum allowed value
20+
value=50, # Default starting value
21+
step=1 # Increment step for the slider
22+
)
23+
start = None
24+
stop = None
25+
#max_ybins = 50
26+
#max_xbins = 50
27+
# Colors are scaled from highest to lowest. You can add
28+
# additional values between 0 and 1 to add color ranges.
29+
colorScale = [
30+
[0, "white"], # lower values
31+
[0.5, "gray"], # middle value
32+
[1, "blue"] # higher values
33+
]
34+
showScale = False
35+
layout = {
36+
"title": "Histograms by Step",
37+
"xaxis": {
38+
"ticks": "",
39+
"side": "bottom",
40+
"title": "Steps"
41+
},
42+
"yaxis": {
43+
"ticks": "",
44+
"ticksuffix": " ",
45+
"autosize": True,
46+
"title": "Weights"
47+
}
48+
}
49+
50+
def collect(histogram, start=None, stop=None, bins=50):
    """
    Collect the counts for the given range into bins.

    Args:
        histogram: object providing counts_compressed(), values,
            get_counts() and get_bin_index()
        start: optional, float, start of range to display. When
            omitted, the value indexed by the first entry of
            counts_compressed() is used, or -1.0 if there is none.
        stop: optional, float, end of range to display. When
            omitted, the value indexed by the last entry of
            counts_compressed() is used if there is more than one
            entry, otherwise 1.0.
        bins: optional, int, number of bins

    Returns a list of dicts, one per virtual bin, with the keys
    "value_start", "value_stop", "bin_index_start",
    "bin_index_stop" and "count".
    """
    compressed = histogram.counts_compressed()
    if start is None:
        start = histogram.values[compressed[0][0]] if len(compressed) > 0 else -1.0
    if stop is None:
        stop = histogram.values[compressed[-1][0]] if len(compressed) > 1 else 1.0

    width = (stop - start) / bins
    # One extra bin width is requested past `stop`.
    upper_limit = stop + width
    counts = histogram.get_counts(start, upper_limit, width)

    collected = []
    lower = start
    for position in range(len(counts)):
        # Edges are accumulated by repeated addition of the bin width.
        upper = lower + width
        if not upper <= upper_limit:
            break
        collected.append({
            "value_start": lower,
            "value_stop": upper,
            "bin_index_start": histogram.get_bin_index(lower),
            "bin_index_stop": histogram.get_bin_index(upper),
            "count": counts[position],
        })
        lower = upper
    return collected
96+
97+
def get_histogram_indices(length, max_xbins):
    """
    Get indices from list of histograms, sampling if necessary.

    Args:
        length: int, number of histograms available
        max_xbins: int, maximum number of indices to return

    Returns a list of indices in ascending order. When length
    exceeds max_xbins, the first and last index are always kept
    and the remaining slots are filled with a random sample of the
    interior indices; otherwise every index is returned.
    """
    if length <= max_xbins:
        return list(range(length))
    if max_xbins < 2:
        # No room for both endpoints (and random.sample would be
        # asked for a negative count): keep the leading indices.
        return list(range(max(max_xbins, 0)))
    interior = random.sample(range(1, length - 1), max_xbins - 2)
    # random.sample returns selection order; sort so the chosen
    # histograms stay in their original (step) order.
    return [0] + sorted(interior) + [length - 1]
110+
111+
def get_histogram_data(
    experiment,
    asset,
):
    """
    Build the heatmap inputs for one logged histogram asset.

    Reads the module-level max_xbins (how many histograms to keep)
    and max_ybins (how many value bins to use).

    Args:
        experiment: object whose get_asset() returns the asset JSON
        asset: dict with an "assetId" key

    Returns a list [xValues, yValues, zValues, min_val, max_val]:
        xValues: array of value-bin edges from min_val upwards
        yValues: list with the step each histogram was logged at
        zValues: 2D array of counts, one row per value bin and one
            column per histogram
        min_val, max_val: overall value range across histograms

    When there are no histograms, or their combined value range is
    empty, xValues and yValues are empty and zValues has shape
    (max_ybins, 0).
    """
    asset_json = experiment.get_asset(asset["assetId"], return_type="json")
    # {"histograms": [{"step": num, "histogram": {"index_values"}}, ...]
    logged = asset_json["histograms"]
    histograms = []
    # Sorted so columns keep the logged order however indices were sampled.
    for index in sorted(get_histogram_indices(len(logged), max_xbins)):
        entry = logged[index]
        histogram = Histogram.from_json(entry["histogram"])
        histogram.logged_at_step = entry["step"]
        histograms.append(histogram)

    # First, find the overall min/max of all histograms:
    min_val, max_val = float("+inf"), float("-inf")
    for histogram in histograms:
        # {"value_start", "value_stop", "bin_index_start", "bin_index_stop", "count"}
        data = collect(histogram, start=None, stop=None, bins=max_ybins)
        if not data:
            # Nothing collected for this histogram; it cannot
            # contribute to the range.
            continue
        min_val = min(min_val, data[0]["value_start"])
        max_val = max(max_val, data[-1]["value_stop"])

    if not min_val < max_val:
        # No histograms, or a zero-width range: np.arange below would
        # fail on a non-finite or zero step, so report "no data".
        return [np.array([]), [], np.empty((max_ybins, 0)), min_val, max_val]

    span = (max_val - min_val) / max_ybins
    xValues = np.arange(min_val, max_val + span, span)
    zValues = []
    yValues = []
    for histogram in histograms:
        # Re-collect every histogram over the shared range so that
        # all columns use the same bins.
        data = collect(histogram, start=min_val, stop=max_val, bins=max_ybins)
        zValues.append([entry["count"] for entry in data])
        yValues.append(histogram.logged_at_step)
    return [xValues, yValues, np.transpose(zValues), min_val, max_val]
143+
144+
def plot_histogram(experiment, asset):
    """
    Render the histograms of one asset as a heatmap.

    Args:
        experiment: passed through to get_histogram_data()
        asset: passed through to get_histogram_data()

    Returns None. Prints a short HTML message instead of plotting
    when there is no histogram data.
    """
    bin_edges, steps, data, _xmin, _xmax = get_histogram_data(
        experiment,
        asset,
    )
    # np.size is safe on an empty array, where indexing data[0]
    # would raise IndexError.
    if np.size(data) == 0:
        print("<h1>No histogram data available</h1>")
        return None

    # Transposed, so steps run along x and value bins along y.
    # bin_edges is trimmed to the number of rows in data.
    fig = px.imshow(
        data,
        x=steps,
        y=bin_edges[:len(data)],
        aspect="auto",
        color_continuous_scale=colorScale,
    )
    st.plotly_chart(fig)
162+
163+
164+
# Panel entry point: choose an experiment, then one of its
# 'histogram_combined_3d' assets, and plot it.
api = API()
experiments = api.get_panel_experiments()
# A single experiment is used directly; otherwise a dropdown is shown.
selected_experiment = (
    experiments[0]
    if len(experiments) == 1
    else ui.dropdown("Experiments: ", experiments)
)

if not selected_experiment:
    print("No experiments available.")
else:
    # Assets are offered sorted by file name.
    assets = sorted(
        selected_experiment.get_asset_list('histogram_combined_3d'),
        key=lambda entry: entry["fileName"],
    )
    selected_histogram = ui.dropdown(
        "Histogram: ",
        assets,
        format_func=lambda entry: entry["fileName"],
    )
    if not selected_histogram:
        print("No histograms available.")
    else:
        plot_histogram(selected_experiment, selected_histogram)

0 commit comments

Comments
 (0)