@@ -266,8 +266,16 @@ def _h12_gwss(
266
266
# Compute window midpoints.
267
267
pos = ds_haps ["variant_position" ].values
268
268
x = allel .moving_statistic (pos , statistic = np .mean , size = window_size )
269
+ contigs = np .asarray (
270
+ allel .moving_statistic (
271
+ ds_haps ["variant_contig" ].values ,
272
+ statistic = np .median ,
273
+ size = window_size ,
274
+ ),
275
+ dtype = int ,
276
+ )
269
277
270
- results = dict (x = x , h12 = h12 )
278
+ results = dict (x = x , h12 = h12 , contigs = contigs )
271
279
272
280
return results
273
281
@@ -277,6 +285,7 @@ def _h12_gwss(
277
285
returns = dict (
278
286
x = "An array containing the window centre point genomic positions." ,
279
287
h12 = "An array with h12 statistic values for each window." ,
288
+ contigs = "An array with the contig for each window. The median is chosen for windows overlapping a change of contig." ,
280
289
),
281
290
)
282
291
def h12_gwss (
@@ -297,10 +306,10 @@ def h12_gwss(
297
306
random_seed : base_params .random_seed = 42 ,
298
307
chunks : base_params .chunks = base_params .native_chunks ,
299
308
inline_array : base_params .inline_array = base_params .inline_array_default ,
300
- ) -> Tuple [np .ndarray , np .ndarray ]:
309
+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
301
310
# Change this name if you ever change the behaviour of this function, to
302
311
# invalidate any previously cached data.
303
- name = "h12_gwss_v1 "
312
+ name = "h12_gwss_contig_v1 "
304
313
305
314
params = dict (
306
315
contig = contig ,
@@ -327,8 +336,9 @@ def h12_gwss(
327
336
328
337
x = results ["x" ]
329
338
h12 = results ["h12" ]
339
+ contigs = results ["contigs" ]
330
340
331
- return x , h12
341
+ return x , h12 , contigs
332
342
333
343
@check_types
334
344
@doc (
@@ -354,14 +364,15 @@ def plot_h12_gwss_track(
354
364
sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
355
365
width : gplt_params .width = gplt_params .width_default ,
356
366
height : gplt_params .height = 200 ,
367
+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
357
368
show : gplt_params .show = True ,
358
369
x_range : Optional [gplt_params .x_range ] = None ,
359
370
output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
360
371
chunks : base_params .chunks = base_params .native_chunks ,
361
372
inline_array : base_params .inline_array = base_params .inline_array_default ,
362
373
) -> gplt_params .figure :
363
374
# Compute H12.
364
- x , h12 = self .h12_gwss (
375
+ x , h12 , contigs = self .h12_gwss (
365
376
contig = contig ,
366
377
analysis = analysis ,
367
378
window_size = window_size ,
@@ -412,15 +423,14 @@ def plot_h12_gwss_track(
412
423
)
413
424
414
425
# Plot H12.
415
- fig .scatter (
416
- x = x ,
417
- y = h12 ,
418
- marker = "circle" ,
419
- size = 3 ,
420
- line_width = 1 ,
421
- line_color = "black" ,
422
- fill_color = None ,
423
- )
426
+ for s in set (contigs ):
427
+ idxs = contigs == s
428
+ fig .scatter (
429
+ x = x [idxs ],
430
+ y = h12 [idxs ],
431
+ marker = "circle" ,
432
+ color = contig_colors [s % len (contig_colors )],
433
+ )
424
434
425
435
# Tidy up the plot.
426
436
fig .yaxis .axis_label = "H12"
@@ -457,6 +467,7 @@ def plot_h12_gwss(
457
467
sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
458
468
width : gplt_params .width = gplt_params .width_default ,
459
469
track_height : gplt_params .track_height = 170 ,
470
+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
460
471
genes_height : gplt_params .genes_height = gplt_params .genes_height_default ,
461
472
show : gplt_params .show = True ,
462
473
output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
@@ -479,6 +490,7 @@ def plot_h12_gwss(
479
490
sizing_mode = sizing_mode ,
480
491
width = width ,
481
492
height = track_height ,
493
+ contig_colors = contig_colors ,
482
494
show = False ,
483
495
output_backend = output_backend ,
484
496
chunks = chunks ,
@@ -575,7 +587,7 @@ def plot_h12_gwss_multi_overlay_track(
575
587
)
576
588
577
589
# Determine X axis range.
578
- x , _ = res [list (cohort_queries .keys ())[0 ]]
590
+ x , _ , _ = res [list (cohort_queries .keys ())[0 ]]
579
591
x_min = x [0 ]
580
592
x_max = x [- 1 ]
581
593
if x_range is None :
@@ -610,7 +622,7 @@ def plot_h12_gwss_multi_overlay_track(
610
622
)
611
623
612
624
# Plot H12.
613
- for i , (cohort_label , (x , h12 )) in enumerate (res .items ()):
625
+ for i , (cohort_label , (x , h12 , contig )) in enumerate (res .items ()):
614
626
fig .scatter (
615
627
x = x ,
616
628
y = h12 ,
0 commit comments