Skip to content

Commit 7eaedb5

Browse files
committed
chore: Add upset_altair
hms-dbmi#4
1 parent 3115b14 commit 7eaedb5

File tree

1 file changed

+368
-0
lines changed

1 file changed

+368
-0
lines changed

altair_upset/upset_altair.py

Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
import altair as alt
2+
import pandas as pd
3+
4+
def _melt_intersections(data, sets, sort_order):
    """Count rows per set-membership pattern and return them in long form.

    The caller's DataFrame is never modified: the grouping produces a new
    frame, and every later step operates on that new frame.

    Parameters:
        - data (pandas.DataFrame): One row per element, one column per set.
        - sets (list): Names of the set columns to group by.
        - sort_order (str): "ascending" sorts counts ascending; any other value
          sorts descending.

    Return:
        pandas.DataFrame with the columns `intersection_id`, `count`, `degree`,
        `set` and `is_intersect` (one row per intersection/set pair).
    """
    counts = data.groupby(sets).size().reset_index(name="count")
    counts["intersection_id"] = counts.index
    # The degree is the number of sets that take part in an intersection.
    counts["degree"] = counts[sets].sum(axis=1)
    counts = counts.sort_values(by=["count"], ascending=(sort_order == "ascending"))

    long_form = pd.melt(counts, id_vars=["intersection_id", "count", "degree"])
    return long_form.rename(columns={"variable": "set", "value": "is_intersect"})


def visualize(
    data=None,
    title="",
    subtitle="",
    sets=None,
    abbre=None,
    sort_by="frequency",
    sort_order="ascending",
    width=1200,
    height=700,
    height_ratio=0.6,
    horizontal_bar_chart_width=300,
    color_range=None,
    highlight_color="#EA4667",
    glyph_size=200,
    set_label_bg_size=1000,
    line_connection_size=2,
    horizontal_bar_size=20,
    vertical_bar_label_size=16,
    vertical_bar_padding=20
):
    """
    This function generates Altair-based interactive UpSet plots.

    Parameters:
        - data (pandas.DataFrame): Tabular data containing the membership of each element (row) in
            exclusive intersecting sets (column). The frame is not modified.
        - title (str): Title of the chart.
        - subtitle (str or list): Subtitle of the chart; a list gives one line per item.
        - sets (list): List of set names of interest to show in the UpSet plots.
            This list reflects the order of sets to be shown in the plots as well.
        - abbre (list): Abbreviated set names. Defaults to the set names themselves; also
            dropped (with a message) when its length differs from `sets`.
        - sort_by (str): "frequency" or "degree"
        - sort_order (str): "ascending" or "descending"
        - width (int): Horizontal size of the UpSet plot.
        - height (int): Vertical size of the UpSet plot.
        - height_ratio (float): Ratio of height between upper and under views, ranges from 0 to 1.
            Values outside that range are replaced by 0.5 (with a message).
        - horizontal_bar_chart_width (int): Width of horizontal bar chart on the bottom-right.
        - color_range (list): Color to encode sets. Defaults to
            ["#55A8DB", "#3070B5", "#30363F", "#F1AD60", "#DF6234", "#BDC6CA"].
        - highlight_color (str): Color to encode intersecting sets upon mouse hover.
        - glyph_size (int): Size of UpSet glyph (⬤).
        - set_label_bg_size (int): Size of label background in the horizontal bar chart.
        - line_connection_size (int): width of lines in matrix view.
        - horizontal_bar_size (int): Height of bars in the horizontal bar chart.
        - vertical_bar_label_size (int): Font size of texts in the vertical bar chart on the top.
        - vertical_bar_padding (int): Gap between a pair of bars in the vertical bar charts.

    Return:
        Altair chart object, or None (after printing a message) when `data` or `sets`
        is missing or `sets` is empty.
    """

    # Accept any sequence of set names; list operations are used below.
    sets = list(sets) if sets is not None else []
    if data is None or not sets:
        print("No data and/or a list of sets are provided")
        return None
    if (height_ratio < 0) or (1 < height_ratio):
        print("height_ratio set to 0.5")
        height_ratio = 0.5
    # `abbre` is optional, so only compare lengths when it was actually given.
    if abbre is not None and len(sets) != len(abbre):
        abbre = None
        print("Dropping the `abbre` list because the lengths of `sets` and `abbre` are not identical.")
    abbre = sets if abbre is None else list(abbre)
    if color_range is None:
        # Built per call so that no list object is shared between calls.
        color_range = ["#55A8DB", "#3070B5", "#30363F", "#F1AD60", "#DF6234", "#BDC6CA"]

    # --- Data preprocessing ---
    data = _melt_intersections(data, sets, sort_order)

    set_to_abbre = pd.DataFrame(list(zip(sets, abbre)), columns=["set", "set_abbre"])
    set_to_order = pd.DataFrame(
        [(name, position) for position, name in enumerate(sets, start=1)],
        columns=["set", "set_order"]
    )

    def _membership_term(name):
        # Escape backslashes and single quotes so the set name stays a valid
        # string literal inside the expression handed to `transform_calculate`.
        escaped = str(name).replace("\\", "\\\\").replace("'", "\\'")
        return f"(isDefined(datum['{escaped}']) ? datum['{escaped}'] : 0)"

    # Sum of the per-set membership flags; sets removed through the legend are
    # undefined after the pivot and therefore count as 0.
    degree_calculation = "+".join(_membership_term(name) for name in sets)

    # --- Selections ---
    # NOTE(review): `selection_multi`, `selection_single` and `add_selection` look like the
    # Altair 4 spelling of this API; confirm against the Altair version this project pins.
    legend_selection = alt.selection_multi(fields=["set"], bind="legend")
    color_selection = alt.selection_single(fields=["intersection_id"], on="mouseover")
    opacity_selection = alt.selection_single(fields=["intersection_id"])

    # --- Styles ---
    vertical_bar_chart_height = height * height_ratio
    matrix_height = height - vertical_bar_chart_height
    matrix_width = width - horizontal_bar_chart_width

    # Guard against an empty table (division by zero) and against so many
    # intersections that the padding would push the bar size below one pixel.
    intersection_count = max(1, data["intersection_id"].nunique())
    vertical_bar_size = max(1, min(30, width / intersection_count - vertical_bar_padding))

    main_color = "#3A3A3A"
    brush_opacity = alt.condition(~opacity_selection, alt.value(1), alt.value(0.6))
    brush_color = alt.condition(~color_selection, alt.value(main_color), alt.value(highlight_color))

    # Only short labels (judged by the first abbreviation) get a colored circle behind them.
    is_show_horizontal_bar_label_bg = len(abbre[0]) <= 2
    horizontal_bar_label_bg_color = "white" if is_show_horizontal_bar_label_bg else "black"

    x_sort = alt.Sort(
        field="count" if sort_by == "frequency" else "degree",
        order=sort_order
    )
    tooltip = [
        alt.Tooltip("max(count):Q", title="Cardinality"),
        alt.Tooltip("degree:Q", title="Degree")
    ]

    # --- Plots ---
    # To use native interactivity in Altair, we are using the data transformation functions
    # supported in Altair.
    base = alt.Chart(data).transform_filter(
        legend_selection
    ).transform_pivot(
        # Right before this operation, columns should be:
        # `count`, `set`, `is_intersect`, (`intersection_id`, `degree`, `set_order`, `set_abbre`)
        # where (fields with brackets) should be dropped and recalculated later.
        "set",
        op="max",
        groupby=["intersection_id", "count"],
        value="is_intersect"
    ).transform_aggregate(
        # count, set1, set2, ...
        count="sum(count)",
        groupby=sets
    ).transform_calculate(
        # count, set1, set2, ...
        degree=degree_calculation
    ).transform_filter(
        # count, set1, set2, ..., degree
        alt.datum["degree"] != 0
    ).transform_window(
        # count, set1, set2, ..., degree
        intersection_id="row_number()",
        frame=[None, None]
    ).transform_fold(
        # count, set1, set2, ..., degree, intersection_id
        sets, as_=["set", "is_intersect"]
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id
        lookup="set",
        from_=alt.LookupData(set_to_abbre, "set", ["set_abbre"])
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        lookup="set",
        from_=alt.LookupData(set_to_order, "set", ["set_order"])
    ).transform_filter(
        # Make sure to remove the filtered sets.
        legend_selection
    ).transform_window(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        set_order="distinct(set)",
        frame=[None, 0],
        sort=[{"field": "set_order"}]
    )
    # Now, we have data in the following format:
    # count, set, is_intersect, degree, intersection_id, set_abbre

    # Cardinality by intersecting sets (vertical bar chart)
    vertical_bar = base.mark_bar(color=main_color, size=vertical_bar_size).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=True),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "max(count):Q",
            axis=alt.Axis(grid=False, tickCount=3, orient='right'),
            title="Intersection Size"
        ),
        color=brush_color,
        tooltip=tooltip
    ).properties(
        width=matrix_width,
        height=vertical_bar_chart_height
    )

    vertical_bar_text = vertical_bar.mark_text(
        color=main_color,
        dy=-10,
        size=vertical_bar_label_size
    ).encode(
        text=alt.Text("count:Q", format=".0f")
    )

    vertical_bar_chart = (vertical_bar + vertical_bar_text).add_selection(
        color_selection
    )

    # UpSet glyph view (matrix view)
    circle_bg = vertical_bar.mark_circle(size=glyph_size, opacity=1).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None
        ),
        color=alt.value("#E6E6E6")
    ).properties(
        height=matrix_height
    )

    # Light stripes behind every other set row.
    rect_bg = circle_bg.mark_rect().transform_filter(
        alt.datum["set_order"] % 2 == 1
    ).encode(
        color=alt.value("#F7F7F7")
    )

    circle = circle_bg.transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        color=brush_color
    )

    line_connection = vertical_bar.mark_bar(size=line_connection_size, color=main_color).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        y=alt.Y("min(set_order):N"),
        y2=alt.Y2("max(set_order):N")
    )

    matrix_view = (circle + rect_bg + circle_bg + line_connection + circle).add_selection(
        # Duplicate `circle` is to properly show tooltips.
        color_selection
    )

    # Cardinality by sets (horizontal bar chart)
    horizontal_bar_label_bg = base.mark_circle(size=set_label_bg_size).encode(
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None,
        ),
        color=alt.Color(
            "set:N",
            scale=alt.Scale(domain=sets, range=color_range),
            title=None
        ),
        opacity=alt.value(1)
    )
    horizontal_bar_label = horizontal_bar_label_bg.mark_text(
        align="center"
    ).encode(
        text=alt.Text("set_abbre:N"),
        color=alt.value(horizontal_bar_label_bg_color)
    )
    horizontal_bar_axis = (
        (horizontal_bar_label_bg + horizontal_bar_label)
        if is_show_horizontal_bar_label_bg
        else horizontal_bar_label
    )

    horizontal_bar = horizontal_bar_label_bg.mark_bar(
        size=horizontal_bar_size
    ).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        x=alt.X(
            "sum(count):Q",
            axis=alt.Axis(grid=False, tickCount=3),
            title="Set Size"
        )
    ).properties(
        width=horizontal_bar_chart_width
    )

    # Concat Plots
    upsetaltair = alt.vconcat(
        vertical_bar_chart,
        alt.hconcat(
            matrix_view,
            horizontal_bar_axis, horizontal_bar,  # horizontal bar chart
            spacing=5
        ).resolve_scale(
            y="shared"
        ),
        spacing=20
    ).add_selection(
        legend_selection
    )

    # Apply top-level configuration
    upsetaltair = upsetaltair_top_level_configuration(
        upsetaltair,
        legend_orient="top",
        legend_symbol_size=set_label_bg_size / 2.0
    ).properties(
        title={
            "text": title,
            "subtitle": subtitle,
            "fontSize": 20,
            "fontWeight": 500,
            "subtitleColor": main_color,
            "subtitleFontSize": 14
        }
    )

    return upsetaltair
317+
318+
# Top-level altair configuration
319+
def upsetaltair_top_level_configuration(
    base,
    legend_orient="top-left",
    legend_symbol_size=30
):
    """
    Apply the shared top-level view, title, axis, legend and concat settings to a chart.

    Parameters:
        - base: Chart object exposing the `configure_view`, `configure_title`,
            `configure_axis`, `configure_legend` and `configure_concat` methods.
        - legend_orient (str): Value passed as the legend's `orient` setting.
        - legend_symbol_size (int): Value passed as the legend's `symbolSize` setting.

    Return:
        The object returned by the last `configure_*` call in the chain.
    """
    title_options = {
        "fontSize": 18,
        "fontWeight": 400,
        "anchor": "start",
        "subtitlePadding": 10,
    }
    axis_options = {
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "titlePadding": 10,
    }
    legend_options = {
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "padding": 20,
        "orient": legend_orient,
        "symbolType": "circle",
        "symbolSize": legend_symbol_size,
    }

    # Each step works on the result of the previous one, in the same order as before.
    configured = base.configure_view(stroke=None)
    configured = configured.configure_title(**title_options)
    configured = configured.configure_axis(**axis_options)
    configured = configured.configure_legend(**legend_options)
    return configured.configure_concat(spacing=0)
349+
350+
if __name__ == '__main__':
    # Demo: render the COVID symptom data as an UpSet plot.
    # Use the latest data from https://figshare.com/articles/covid_symptoms_table_csv/12148893
    symptom_table = pd.read_csv("https://ndownloader.figshare.com/files/22339791")

    # Set columns to plot, paired with the short labels drawn next to each row.
    symptom_sets = ["Shortness of Breath", "Diarrhea", "Fever", "Cough", "Anosmia", "Fatigue"]
    symptom_labels = ["B", "D", "Fe", "C", "A", "Fa"]

    credit_lines = [
        "Story & Data: https://www.nature.com/articles/d41586-020-00154-w",
        "Altair-based UpSet Plot: https://github.com/hms-dbmi/upset-altair-notebook"
    ]

    # A copy is passed so the loaded table itself is left untouched.
    demo_chart = visualize(
        data=symptom_table.copy(),
        title="Symptoms Reported by Users of the COVID Symptom Tracker App",
        subtitle=credit_lines,
        sets=symptom_sets,
        abbre=symptom_labels,
        sort_by="frequency",
        sort_order="ascending",
    )

    demo_chart.display()

0 commit comments

Comments
 (0)