|
| 1 | +import altair as alt |
| 2 | +import pandas as pd |
| 3 | + |
def visualize(
    data=None,
    title="",
    subtitle="",
    sets=None,
    abbre=None,
    sort_by="frequency",
    sort_order="ascending",
    width=1200,
    height=700,
    height_ratio=0.6,
    horizontal_bar_chart_width=300,
    color_range=["#55A8DB", "#3070B5", "#30363F", "#F1AD60", "#DF6234", "#BDC6CA"],
    highlight_color="#EA4667",
    glyph_size=200,
    set_label_bg_size=1000,
    line_connection_size=2,
    horizontal_bar_size=20,
    vertical_bar_label_size=16,
    vertical_bar_padding=20
):
    """
    Generate an Altair-based interactive UpSet plot.

    The chart is made of three linked views: a vertical bar chart with the
    size of every intersection (top), a glyph matrix showing which sets take
    part in each intersection (bottom-left) and a horizontal bar chart with
    the size of every set (bottom-right).

    Parameters:
        - data (pandas.DataFrame): Tabular data containing the membership of each element (row) in
            exclusive intersecting sets (column). The set columns are summed to obtain the degree and
            compared against 1, so they are expected to hold 0/1 flags. The frame is not modified.
        - title (str): Title of the whole chart.
        - subtitle (str or list): Subtitle of the whole chart; forwarded as-is to the chart title.
        - sets (list): List of set names of interest to show in the UpSet plots.
            This list reflects the order of sets to be shown in the plots as well.
        - abbre (list): Abbreviated set names. Optional; the full set names are used when it is
            omitted or when its length differs from `sets`.
        - sort_by (str): "frequency" or "degree"
        - sort_order (str): "ascending" or "descending"
        - width (int): Horizontal size of the UpSet plot.
        - height (int): Vertical size of the UpSet plot.
        - height_ratio (float): Share of `height` given to the upper (vertical bar) view, ranges from 0 to 1.
            Values outside that range are replaced by 0.5.
        - horizontal_bar_chart_width (int): Width of horizontal bar chart on the bottom-right.
        - color_range (list): Color to encode sets.
        - highlight_color (str): Color to encode intersecting sets upon mouse hover.
        - glyph_size (int): Size of UpSet glyph (⬤).
        - set_label_bg_size (int): Size of label background in the horizontal bar chart.
        - line_connection_size (int): width of lines in matrix view.
        - horizontal_bar_size (int): Height of bars in the horizontal bar chart.
        - vertical_bar_label_size (int): Font size of texts in the vertical bar chart on the top.
        - vertical_bar_padding (int): Gap between a pair of bars in the vertical bar charts.

    Return:
        Altair `Chart` object, or `None` (after printing a message) when `data`
        or `sets` is missing or `sets` is empty.
    """

    # Accept any sequence of names; the code below relies on list semantics.
    if sets is not None:
        sets = list(sets)

    if (data is None) or not sets:
        print("No data and/or a list of sets are provided")
        return None
    if (height_ratio < 0) or (1 < height_ratio):
        print("height_ratio set to 0.5")
        height_ratio = 0.5

    # `abbre` is optional: only compare lengths when it was actually given.
    if abbre is not None:
        abbre = list(abbre)
        if len(sets) != len(abbre):
            abbre = None
            print("Dropping the `abbre` list because the lengths of `sets` and `abbre` are not identical.")
    if abbre is None:
        abbre = sets

    """
    Data Preprocessing
    """
    data = _build_intersection_table(data, sets, sort_order)

    set_to_abbre = pd.DataFrame(
        [[name, short] for name, short in zip(sets, abbre)],
        columns=["set", "set_abbre"]
    )
    set_to_order = pd.DataFrame(
        [[name, order] for order, name in enumerate(sets, start=1)],
        columns=["set", "set_order"]
    )

    # Vega expression that recomputes the degree on the client side, treating
    # sets hidden through the legend (undefined after the pivot) as 0.
    degree_calculation = "+".join(
        f"(isDefined(datum['{s}']) ? datum['{s}'] : 0)" for s in sets
    )

    """
    Selections
    """
    legend_selection = alt.selection_multi(fields=["set"], bind="legend")
    color_selection = alt.selection_single(fields=["intersection_id"], on="mouseover")

    """
    Styles
    """
    vertical_bar_chart_height = height * height_ratio
    matrix_height = height - vertical_bar_chart_height
    matrix_width = width - horizontal_bar_chart_width

    # With many intersections `width / n - padding` drops to zero or below,
    # which is not a drawable bar size, so keep at least one pixel.
    intersection_count = data["intersection_id"].nunique()
    vertical_bar_size = max(1, min(30, width / intersection_count - vertical_bar_padding))

    main_color = "#3A3A3A"
    brush_color = alt.condition(~color_selection, alt.value(main_color), alt.value(highlight_color))

    # Only the first abbreviation is inspected: short labels are drawn in
    # white on a colored circle, long labels as plain black text.
    is_show_horizontal_bar_label_bg = len(abbre[0]) <= 2
    horizontal_bar_label_bg_color = "white" if is_show_horizontal_bar_label_bg else "black"

    x_sort = alt.Sort(
        field="count" if sort_by == "frequency" else "degree",
        order=sort_order
    )
    tooltip = [
        alt.Tooltip("max(count):Q", title="Cardinality"),
        alt.Tooltip("degree:Q", title="Degree")
    ]

    """
    Plots
    """
    # To use native interactivity in Altair, we are using the data transformation functions
    # supported in Altair.
    base = alt.Chart(data).transform_filter(
        legend_selection
    ).transform_pivot(
        # Right before this operation, columns should be:
        # `count`, `set`, `is_intersect`, (`intersection_id`, `degree`, `set_order`, `set_abbre`)
        # where (fields with brackets) should be dropped and recalculated later.
        "set",
        op="max",
        groupby=["intersection_id", "count"],
        value="is_intersect"
    ).transform_aggregate(
        # count, set1, set2, ...
        count="sum(count)",
        groupby=sets
    ).transform_calculate(
        # count, set1, set2, ...
        degree=degree_calculation
    ).transform_filter(
        # count, set1, set2, ..., degree
        alt.datum["degree"] != 0
    ).transform_window(
        # count, set1, set2, ..., degree
        intersection_id="row_number()",
        frame=[None, None]
    ).transform_fold(
        # count, set1, set2, ..., degree, intersection_id
        sets, as_=["set", "is_intersect"]
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id
        lookup="set",
        from_=alt.LookupData(set_to_abbre, "set", ["set_abbre"])
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        lookup="set",
        from_=alt.LookupData(set_to_order, "set", ["set_order"])
    ).transform_filter(
        # Make sure to remove the filtered sets.
        legend_selection
    ).transform_window(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        set_order="distinct(set)",
        frame=[None, 0],
        sort=[{"field": "set_order"}]
    )
    # Now, we have data in the following format:
    # count, set, is_intersect, degree, intersection_id, set_abbre

    # Cardinality by intersecting sets (vertical bar chart)
    vertical_bar = base.mark_bar(color=main_color, size=vertical_bar_size).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=True),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "max(count):Q",
            axis=alt.Axis(grid=False, tickCount=3, orient='right'),
            title="Intersection Size"
        ),
        color=brush_color,
        tooltip=tooltip
    ).properties(
        width=matrix_width,
        height=vertical_bar_chart_height
    )

    vertical_bar_text = vertical_bar.mark_text(
        color=main_color,
        dy=-10,
        size=vertical_bar_label_size
    ).encode(
        text=alt.Text("count:Q", format=".0f")
    )

    vertical_bar_chart = (vertical_bar + vertical_bar_text).add_selection(
        color_selection
    )

    # UpSet glyph view (matrix view)
    circle_bg = vertical_bar.mark_circle(size=glyph_size, opacity=1).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None
        ),
        color=alt.value("#E6E6E6")
    ).properties(
        height=matrix_height
    )

    # Light stripes behind every other set row.
    rect_bg = circle_bg.mark_rect().transform_filter(
        alt.datum["set_order"] % 2 == 1
    ).encode(
        color=alt.value("#F7F7F7")
    )

    circle = circle_bg.transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        color=brush_color
    )

    line_connection = vertical_bar.mark_bar(size=line_connection_size, color=main_color).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        y=alt.Y("min(set_order):N"),
        y2=alt.Y2("max(set_order):N")
    )

    matrix_view = (circle + rect_bg + circle_bg + line_connection + circle).add_selection(
        # Duplicate `circle` is to properly show tooltips.
        color_selection
    )

    # Cardinality by sets (horizontal bar chart)
    horizontal_bar_label_bg = base.mark_circle(size=set_label_bg_size).encode(
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None,
        ),
        color=alt.Color(
            "set:N",
            scale=alt.Scale(domain=sets, range=color_range),
            title=None
        ),
        opacity=alt.value(1)
    )
    horizontal_bar_label = horizontal_bar_label_bg.mark_text(
        align="center"
    ).encode(
        text=alt.Text("set_abbre:N"),
        color=alt.value(horizontal_bar_label_bg_color)
    )
    horizontal_bar_axis = (horizontal_bar_label_bg + horizontal_bar_label) if is_show_horizontal_bar_label_bg else horizontal_bar_label

    horizontal_bar = horizontal_bar_label_bg.mark_bar(
        size=horizontal_bar_size
    ).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        x=alt.X(
            "sum(count):Q",
            axis=alt.Axis(grid=False, tickCount=3),
            title="Set Size"
        )
    ).properties(
        width=horizontal_bar_chart_width
    )

    # Concat Plots
    upsetaltair = alt.vconcat(
        vertical_bar_chart,
        alt.hconcat(
            matrix_view,
            horizontal_bar_axis, horizontal_bar,  # horizontal bar chart
            spacing=5
        ).resolve_scale(
            y="shared"
        ),
        spacing=20
    ).add_selection(
        legend_selection
    )

    # Apply top-level configuration
    upsetaltair = upsetaltair_top_level_configuration(
        upsetaltair,
        legend_orient="top",
        legend_symbol_size=set_label_bg_size / 2.0
    ).properties(
        title={
            "text": title,
            "subtitle": subtitle,
            "fontSize": 20,
            "fontWeight": 500,
            "subtitleColor": main_color,
            "subtitleFontSize": 14
        }
    )

    return upsetaltair


def _build_intersection_table(data, sets, sort_order):
    """
    Count the rows of `data` per combination of set memberships and return the
    result in long form, one row per (intersection, set) pair.

    The returned frame has the columns `intersection_id`, `count`, `degree`,
    `set` and `is_intersect`. `data` itself is left untouched.
    """
    # Copy only the set columns so the helper column never reaches the caller's frame.
    table = data[sets].copy()
    table["count"] = 0
    table = table.groupby(sets).count().reset_index()

    table["intersection_id"] = table.index
    table["degree"] = table[sets].sum(axis=1)
    table = table.sort_values(by=["count"], ascending=(sort_order == "ascending"))

    table = pd.melt(table, id_vars=["intersection_id", "count", "degree"])
    return table.rename(columns={"variable": "set", "value": "is_intersect"})
| 317 | + |
| 318 | +# Top-level altair configuration |
def upsetaltair_top_level_configuration(
    base,
    legend_orient="top-left",
    legend_symbol_size=30
):
    """
    Apply the shared top-level styling to a chart and return the result.

    Parameters:
        - base: Chart object exposing the `configure_*` methods used below.
        - legend_orient (str): Value forwarded as the legend's `orient`.
        - legend_symbol_size (number): Value forwarded as the legend's `symbolSize`.

    Return:
        Whatever the final `configure_concat` call returns.
    """
    view_options = {"stroke": None}
    title_options = {
        "fontSize": 18,
        "fontWeight": 400,
        "anchor": "start",
        "subtitlePadding": 10,
    }
    axis_options = {
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "titlePadding": 10,
    }
    legend_options = {
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "padding": 20,
        "orient": legend_orient,
        "symbolType": "circle",
        "symbolSize": legend_symbol_size,
    }
    concat_options = {"spacing": 0}

    # Each configure_* call returns a chart, so the steps are applied in order.
    styled = base.configure_view(**view_options)
    styled = styled.configure_title(**title_options)
    styled = styled.configure_axis(**axis_options)
    styled = styled.configure_legend(**legend_options)
    styled = styled.configure_concat(**concat_options)
    return styled
| 349 | + |
if __name__ == '__main__':

    # Demo: render an UpSet plot of the COVID symptom data.
    # Use the latest data from https://figshare.com/articles/covid_symptoms_table_csv/12148893
    data_url = "https://ndownloader.figshare.com/files/22339791"
    symptom_sets = ["Shortness of Breath", "Diarrhea", "Fever", "Cough", "Anosmia", "Fatigue"]
    symptom_abbreviations = ["B", "D", "Fe", "C", "A", "Fa"]
    credit_lines = [
        "Story & Data: https://www.nature.com/articles/d41586-020-00154-w",
        "Altair-based UpSet Plot: https://github.com/hms-dbmi/upset-altair-notebook"
    ]

    df = pd.read_csv(data_url)

    # A copy is handed over so the downloaded frame stays as loaded.
    upset_altair = visualize(
        data=df.copy(),
        title="Symptoms Reported by Users of the COVID Symptom Tracker App",
        subtitle=credit_lines,
        sets=symptom_sets,
        abbre=symptom_abbreviations,
        sort_by="frequency",
        sort_order="ascending",
    )

    upset_altair.display()
0 commit comments