Skip to content

Commit 73ae392

Browse files
authored
Merge pull request #591 from malariagen/583-GWSS-colors
Colors in GWSS plots
2 parents b33e3e4 + 5a9c28a commit 73ae392

File tree

6 files changed

+131
-46
lines changed

6 files changed

+131
-46
lines changed

malariagen_data/anoph/gplt_params.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""Parameters for genome plotting functions. N.B., genome plots are always
22
plotted with bokeh."""
33

4-
from typing import Literal, Mapping, Optional, Union, Sequence
4+
from typing import Literal, Mapping, Optional, Union, Final, Sequence
55

66
import bokeh.models
77
from typing_extensions import Annotated, TypeAlias
@@ -112,4 +112,11 @@
112112
"Passed through to bokeh line() function.",
113113
]
114114

115+
contig_colors: TypeAlias = Annotated[
116+
list[str],
117+
"A sequence of colors.",
118+
]
119+
120+
contig_colors_default: Final[contig_colors] = list(bokeh.palettes.d3["Category20b"][5])
121+
115122
colors: TypeAlias = Annotated[Sequence[str], "List of colors."]

malariagen_data/anoph/h12.py

Lines changed: 28 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -266,8 +266,16 @@ def _h12_gwss(
266266
# Compute window midpoints.
267267
pos = ds_haps["variant_position"].values
268268
x = allel.moving_statistic(pos, statistic=np.mean, size=window_size)
269+
contigs = np.asarray(
270+
allel.moving_statistic(
271+
ds_haps["variant_contig"].values,
272+
statistic=np.median,
273+
size=window_size,
274+
),
275+
dtype=int,
276+
)
269277

270-
results = dict(x=x, h12=h12)
278+
results = dict(x=x, h12=h12, contigs=contigs)
271279

272280
return results
273281

@@ -277,6 +285,7 @@ def _h12_gwss(
277285
returns=dict(
278286
x="An array containing the window centre point genomic positions.",
279287
h12="An array with h12 statistic values for each window.",
288+
contigs="An array with the contig for each window. The median is chosen for windows overlapping a change of contig.",
280289
),
281290
)
282291
def h12_gwss(
@@ -297,10 +306,10 @@ def h12_gwss(
297306
random_seed: base_params.random_seed = 42,
298307
chunks: base_params.chunks = base_params.native_chunks,
299308
inline_array: base_params.inline_array = base_params.inline_array_default,
300-
) -> Tuple[np.ndarray, np.ndarray]:
309+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
301310
# Change this name if you ever change the behaviour of this function, to
302311
# invalidate any previously cached data.
303-
name = "h12_gwss_v1"
312+
name = "h12_gwss_contig_v1"
304313

305314
params = dict(
306315
contig=contig,
@@ -327,8 +336,9 @@ def h12_gwss(
327336

328337
x = results["x"]
329338
h12 = results["h12"]
339+
contigs = results["contigs"]
330340

331-
return x, h12
341+
return x, h12, contigs
332342

333343
@check_types
334344
@doc(
@@ -354,14 +364,15 @@ def plot_h12_gwss_track(
354364
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
355365
width: gplt_params.width = gplt_params.width_default,
356366
height: gplt_params.height = 200,
367+
contig_colors: gplt_params.contig_colors = gplt_params.contig_colors_default,
357368
show: gplt_params.show = True,
358369
x_range: Optional[gplt_params.x_range] = None,
359370
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
360371
chunks: base_params.chunks = base_params.native_chunks,
361372
inline_array: base_params.inline_array = base_params.inline_array_default,
362373
) -> gplt_params.figure:
363374
# Compute H12.
364-
x, h12 = self.h12_gwss(
375+
x, h12, contigs = self.h12_gwss(
365376
contig=contig,
366377
analysis=analysis,
367378
window_size=window_size,
@@ -412,15 +423,14 @@ def plot_h12_gwss_track(
412423
)
413424

414425
# Plot H12.
415-
fig.scatter(
416-
x=x,
417-
y=h12,
418-
marker="circle",
419-
size=3,
420-
line_width=1,
421-
line_color="black",
422-
fill_color=None,
423-
)
426+
for s in set(contigs):
427+
idxs = contigs == s
428+
fig.scatter(
429+
x=x[idxs],
430+
y=h12[idxs],
431+
marker="circle",
432+
color=contig_colors[s % len(contig_colors)],
433+
)
424434

425435
# Tidy up the plot.
426436
fig.yaxis.axis_label = "H12"
@@ -457,6 +467,7 @@ def plot_h12_gwss(
457467
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
458468
width: gplt_params.width = gplt_params.width_default,
459469
track_height: gplt_params.track_height = 170,
470+
contig_colors: gplt_params.contig_colors = gplt_params.contig_colors_default,
460471
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
461472
show: gplt_params.show = True,
462473
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
@@ -479,6 +490,7 @@ def plot_h12_gwss(
479490
sizing_mode=sizing_mode,
480491
width=width,
481492
height=track_height,
493+
contig_colors=contig_colors,
482494
show=False,
483495
output_backend=output_backend,
484496
chunks=chunks,
@@ -575,7 +587,7 @@ def plot_h12_gwss_multi_overlay_track(
575587
)
576588

577589
# Determine X axis range.
578-
x, _ = res[list(cohort_queries.keys())[0]]
590+
x, _, _ = res[list(cohort_queries.keys())[0]]
579591
x_min = x[0]
580592
x_max = x[-1]
581593
if x_range is None:
@@ -610,7 +622,7 @@ def plot_h12_gwss_multi_overlay_track(
610622
)
611623

612624
# Plot H12.
613-
for i, (cohort_label, (x, h12)) in enumerate(res.items()):
625+
for i, (cohort_label, (x, h12, contig)) in enumerate(res.items()):
614626
fig.scatter(
615627
x=x,
616628
y=h12,

malariagen_data/anoph/h1x.py

Lines changed: 26 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -84,8 +84,16 @@ def _h1x_gwss(
8484
# Compute window midpoints.
8585
pos = ds1["variant_position"].values
8686
x = allel.moving_statistic(pos, statistic=np.mean, size=window_size)
87+
contigs = np.asarray(
88+
allel.moving_statistic(
89+
ds1["variant_contig"].values,
90+
statistic=np.median,
91+
size=window_size,
92+
),
93+
dtype=int,
94+
)
8795

88-
results = dict(x=x, h1x=h1x)
96+
results = dict(x=x, h1x=h1x, contigs=contigs)
8997

9098
return results
9199

@@ -98,6 +106,7 @@ def _h1x_gwss(
98106
returns=dict(
99107
x="An array containing the window centre point genomic positions.",
100108
h1x="An array with H1X statistic values for each window.",
109+
contigs="An array with the contig for each window. The median is chosen for windows overlapping a change of contig.",
101110
),
102111
)
103112
def h1x_gwss(
@@ -119,10 +128,10 @@ def h1x_gwss(
119128
random_seed: base_params.random_seed = 42,
120129
chunks: base_params.chunks = base_params.native_chunks,
121130
inline_array: base_params.inline_array = base_params.inline_array_default,
122-
) -> Tuple[np.ndarray, np.ndarray]:
131+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
123132
# Change this name if you ever change the behaviour of this function, to
124133
# invalidate any previously cached data.
125-
name = "h1x_gwss_v1"
134+
name = "h1x_gwss_contig_v1"
126135

127136
params = dict(
128137
contig=contig,
@@ -150,8 +159,9 @@ def h1x_gwss(
150159

151160
x = results["x"]
152161
h1x = results["h1x"]
162+
contigs = results["contigs"]
153163

154-
return x, h1x
164+
return x, h1x, contigs
155165

156166
@check_types
157167
@doc(
@@ -181,14 +191,15 @@ def plot_h1x_gwss_track(
181191
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
182192
width: gplt_params.width = gplt_params.width_default,
183193
height: gplt_params.height = 200,
194+
contig_colors: gplt_params.contig_colors = gplt_params.contig_colors_default,
184195
show: gplt_params.show = True,
185196
x_range: Optional[gplt_params.x_range] = None,
186197
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
187198
chunks: base_params.chunks = base_params.native_chunks,
188199
inline_array: base_params.inline_array = base_params.inline_array_default,
189200
) -> gplt_params.figure:
190201
# Compute H1X.
191-
x, h1x = self.h1x_gwss(
202+
x, h1x, contigs = self.h1x_gwss(
192203
contig=contig,
193204
analysis=analysis,
194205
window_size=window_size,
@@ -240,15 +251,14 @@ def plot_h1x_gwss_track(
240251
)
241252

242253
# Plot H1X.
243-
fig.scatter(
244-
x=x,
245-
y=h1x,
246-
marker="circle",
247-
size=3,
248-
line_width=1,
249-
line_color="black",
250-
fill_color=None,
251-
)
254+
for s in set(contigs):
255+
idxs = contigs == s
256+
fig.scatter(
257+
x=x[idxs],
258+
y=h1x[idxs],
259+
marker="circle",
260+
color=contig_colors[s % len(contig_colors)],
261+
)
252262

253263
# Tidy up the plot.
254264
fig.yaxis.axis_label = "H1X"
@@ -289,6 +299,7 @@ def plot_h1x_gwss(
289299
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
290300
width: gplt_params.width = gplt_params.width_default,
291301
track_height: gplt_params.track_height = 190,
302+
contig_colors: gplt_params.contig_colors = gplt_params.contig_colors_default,
292303
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
293304
show: gplt_params.show = True,
294305
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
@@ -312,6 +323,7 @@ def plot_h1x_gwss(
312323
sizing_mode=sizing_mode,
313324
width=width,
314325
height=track_height,
326+
contig_colors=contig_colors,
315327
show=False,
316328
output_backend=output_backend,
317329
chunks=chunks,

notebooks/plot_h12_h1x.ipynb

Lines changed: 56 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,8 @@
7575
"coh2 = \"ML-2_Kati_gamb_2014\"\n",
7676
"coh1_query = f\"cohort_admin2_year == '{coh1}'\"\n",
7777
"coh2_query = f\"cohort_admin2_year == '{coh2}'\"\n",
78-
"contig = \"2L\""
78+
"contig = \"2L\"\n",
79+
"contigs = \"2RL\""
7980
]
8081
},
8182
{
@@ -114,6 +115,23 @@
114115
")"
115116
]
116117
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"id": "4470d24c-8cf1-4d22-b774-0121b4560e27",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"ag3.plot_h12_gwss(\n",
126+
" contig=contigs,\n",
127+
" analysis=\"gamb_colu\",\n",
128+
" window_size=2000,\n",
129+
" sample_query=coh1_query,\n",
130+
" sample_sets=\"3.0\",\n",
131+
" cohort_size=20,\n",
132+
")"
133+
]
134+
},
117135
{
118136
"cell_type": "code",
119137
"execution_count": null,
@@ -173,6 +191,25 @@
173191
")"
174192
]
175193
},
194+
{
195+
"cell_type": "code",
196+
"execution_count": null,
197+
"id": "b4b7a8d2-95d0-48dc-a32f-3bc96aacfb9f",
198+
"metadata": {},
199+
"outputs": [],
200+
"source": [
201+
"ag3.plot_h1x_gwss(\n",
202+
" contig=contigs,\n",
203+
" window_size=2000,\n",
204+
" cohort1_query=coh1_query,\n",
205+
" cohort2_query=coh2_query,\n",
206+
" sample_sets=\"3.0\",\n",
207+
" analysis=\"gamb_colu\",\n",
208+
" cohort_size=20,\n",
209+
" contig_colors=[\"red\", \"green\"]\n",
210+
")"
211+
]
212+
},
176213
{
177214
"cell_type": "code",
178215
"execution_count": null,
@@ -261,6 +298,22 @@
261298
"contig = \"2RL\""
262299
]
263300
},
301+
{
302+
"cell_type": "code",
303+
"execution_count": null,
304+
"id": "1aaa0573-723c-43b1-baea-750172c4dabc",
305+
"metadata": {},
306+
"outputs": [],
307+
"source": []
308+
},
309+
{
310+
"cell_type": "code",
311+
"execution_count": null,
312+
"id": "ffc7dc06-6bdb-42d2-a1fb-878612d10dd1",
313+
"metadata": {},
314+
"outputs": [],
315+
"source": []
316+
},
264317
{
265318
"cell_type": "code",
266319
"execution_count": null,
@@ -364,14 +417,6 @@
364417
" cohort_size=20,\n",
365418
")"
366419
]
367-
},
368-
{
369-
"cell_type": "code",
370-
"execution_count": null,
371-
"id": "67e3bfcc",
372-
"metadata": {},
373-
"outputs": [],
374-
"source": []
375420
}
376421
],
377422
"metadata": {
@@ -382,7 +427,7 @@
382427
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125"
383428
},
384429
"kernelspec": {
385-
"display_name": "malariagen-data-python",
430+
"display_name": "Python 3 (ipykernel)",
386431
"language": "python",
387432
"name": "python3"
388433
},
@@ -396,7 +441,7 @@
396441
"name": "python",
397442
"nbconvert_exporter": "python",
398443
"pygments_lexer": "ipython3",
399-
"version": "3.10.15"
444+
"version": "3.10.11"
400445
},
401446
"vscode": {
402447
"interpreter": {

tests/anoph/test_h12.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,23 +129,28 @@ def test_h12_calibration(fixture, api: AnophelesH12Analysis):
129129

130130
def check_h12_gwss(*, api, h12_params):
131131
# Run main gwss function under test.
132-
x, h12 = api.h12_gwss(**h12_params)
132+
133+
x, h12, contigs = api.h12_gwss(**h12_params)
133134

134135
# Check results.
135136
assert isinstance(x, np.ndarray)
136137
assert isinstance(h12, np.ndarray)
138+
assert isinstance(contigs, np.ndarray)
137139
assert x.ndim == 1
138140
assert h12.ndim == 1
141+
assert contigs.ndim == 1
139142
assert x.shape == h12.shape
143+
assert x.shape == contigs.shape
140144
assert x.dtype.kind == "f"
141145
assert h12.dtype.kind == "f"
146+
assert contigs.dtype.kind == "i"
142147
assert np.all(h12 >= 0)
143148
assert np.all(h12 <= 1)
144149

145150
# Check plotting functions.
146151
fig = api.plot_h12_gwss_track(**h12_params, show=False)
147152
assert isinstance(fig, bokeh.models.Plot)
148-
fig = api.plot_h12_gwss(**h12_params, show=False)
153+
fig = api.plot_h12_gwss(**h12_params, contig_colors=["black", "red"], show=False)
149154
assert isinstance(fig, bokeh.models.GridPlot)
150155

151156

0 commit comments

Comments
 (0)