@@ -265,8 +265,16 @@ def _h12_gwss(
265
265
# Compute window midpoints.
266
266
pos = ds_haps ["variant_position" ].values
267
267
x = allel .moving_statistic (pos , statistic = np .mean , size = window_size )
268
+ contigs = np .asarray (
269
+ allel .moving_statistic (
270
+ ds_haps ["variant_contig" ].values ,
271
+ statistic = np .median ,
272
+ size = window_size ,
273
+ ),
274
+ dtype = int ,
275
+ )
268
276
269
- results = dict (x = x , h12 = h12 )
277
+ results = dict (x = x , h12 = h12 , contigs = contigs )
270
278
271
279
return results
272
280
@@ -276,6 +284,7 @@ def _h12_gwss(
276
284
returns = dict (
277
285
x = "An array containing the window centre point genomic positions." ,
278
286
h12 = "An array with h12 statistic values for each window." ,
287
+ contigs = "An array with the contig for each window. The median is chosen for windows overlapping a change of contig." ,
279
288
),
280
289
)
281
290
def h12_gwss (
@@ -296,10 +305,10 @@ def h12_gwss(
296
305
random_seed : base_params .random_seed = 42 ,
297
306
chunks : base_params .chunks = base_params .native_chunks ,
298
307
inline_array : base_params .inline_array = base_params .inline_array_default ,
299
- ) -> Tuple [np .ndarray , np .ndarray ]:
308
+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
300
309
# Change this name if you ever change the behaviour of this function, to
301
310
# invalidate any previously cached data.
302
- name = "h12_gwss_v1 "
311
+ name = "h12_gwss_v2 "
303
312
304
313
params = dict (
305
314
contig = contig ,
@@ -326,8 +335,9 @@ def h12_gwss(
326
335
327
336
x = results ["x" ]
328
337
h12 = results ["h12" ]
338
+ contigs = results ["contigs" ]
329
339
330
- return x , h12
340
+ return x , h12 , contigs
331
341
332
342
@check_types
333
343
@doc (
@@ -353,14 +363,15 @@ def plot_h12_gwss_track(
353
363
sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
354
364
width : gplt_params .width = gplt_params .width_default ,
355
365
height : gplt_params .height = 200 ,
366
+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
356
367
show : gplt_params .show = True ,
357
368
x_range : Optional [gplt_params .x_range ] = None ,
358
369
output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
359
370
chunks : base_params .chunks = base_params .native_chunks ,
360
371
inline_array : base_params .inline_array = base_params .inline_array_default ,
361
372
) -> gplt_params .figure :
362
373
# Compute H12.
363
- x , h12 = self .h12_gwss (
374
+ x , h12 , contigs = self .h12_gwss (
364
375
contig = contig ,
365
376
analysis = analysis ,
366
377
window_size = window_size ,
@@ -411,15 +422,14 @@ def plot_h12_gwss_track(
411
422
)
412
423
413
424
# Plot H12.
414
- fig .scatter (
415
- x = x ,
416
- y = h12 ,
417
- marker = "circle" ,
418
- size = 3 ,
419
- line_width = 1 ,
420
- line_color = "black" ,
421
- fill_color = None ,
422
- )
425
+ for s in set (contigs ):
426
+ idxs = contigs == s
427
+ fig .scatter (
428
+ x = x [idxs ],
429
+ y = h12 [idxs ],
430
+ marker = "circle" ,
431
+ color = contig_colors [s % len (contig_colors )],
432
+ )
423
433
424
434
# Tidy up the plot.
425
435
fig .yaxis .axis_label = "H12"
@@ -456,6 +466,7 @@ def plot_h12_gwss(
456
466
sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
457
467
width : gplt_params .width = gplt_params .width_default ,
458
468
track_height : gplt_params .track_height = 170 ,
469
+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
459
470
genes_height : gplt_params .genes_height = gplt_params .genes_height_default ,
460
471
show : gplt_params .show = True ,
461
472
output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
@@ -478,6 +489,7 @@ def plot_h12_gwss(
478
489
sizing_mode = sizing_mode ,
479
490
width = width ,
480
491
height = track_height ,
492
+ contig_colors = contig_colors ,
481
493
show = False ,
482
494
output_backend = output_backend ,
483
495
chunks = chunks ,
@@ -574,7 +586,7 @@ def plot_h12_gwss_multi_overlay_track(
574
586
)
575
587
576
588
# Determine X axis range.
577
- x , _ = res [list (cohort_queries .keys ())[0 ]]
589
+ x , _ , _ = res [list (cohort_queries .keys ())[0 ]]
578
590
x_min = x [0 ]
579
591
x_max = x [- 1 ]
580
592
if x_range is None :
@@ -609,7 +621,7 @@ def plot_h12_gwss_multi_overlay_track(
609
621
)
610
622
611
623
# Plot H12.
612
- for i , (cohort_label , (x , h12 )) in enumerate (res .items ()):
624
+ for i , (cohort_label , (x , h12 , contig )) in enumerate (res .items ()):
613
625
fig .scatter (
614
626
x = x ,
615
627
y = h12 ,
0 commit comments