33Arrow Demo
44==========
55
6- Arrow drawing example for the new fancy_arrow facilities.
7-
8- Code contributed by: Rob Knight <[email protected] > 9-
10- usage:
11-
12- python arrow_demo.py realistic|full|sample|extreme
6+ Three ways of drawing arrows to encode arrow "strength" (e.g., transition
7+ probabilities in a Markov model) using arrow length, width, or alpha (opacity).
8+ """
139
10+ import itertools
1411
15- """
1612import matplotlib .pyplot as plt
1713import numpy as np
1814
19- rates_to_bases = {'r1' : 'AT' , 'r2' : 'TA' , 'r3' : 'GA' , 'r4' : 'AG' , 'r5' : 'CA' ,
20- 'r6' : 'AC' , 'r7' : 'GT' , 'r8' : 'TG' , 'r9' : 'CT' , 'r10' : 'TC' ,
21- 'r11' : 'GC' , 'r12' : 'CG' }
22- numbered_bases_to_rates = {v : k for k , v in rates_to_bases .items ()}
23- lettered_bases_to_rates = {v : 'r' + v for k , v in rates_to_bases .items ()}
24-
2515
26- def make_arrow_plot (data , size = 4 , display = 'length' , shape = 'right' ,
27- max_arrow_width = 0.03 , arrow_sep = 0.02 , alpha = 0.5 ,
28- normalize_data = False , ec = None , labelcolor = None ,
29- head_starts_at_zero = True ,
30- rate_labels = lettered_bases_to_rates ,
31- ** kwargs ):
16+ def make_arrow_graph (ax , data , size = 4 , display = 'length' , shape = 'right' ,
17+ max_arrow_width = 0.03 , arrow_sep = 0.02 , alpha = 0.5 ,
18+ normalize_data = False , ec = None , labelcolor = None ,
19+ ** kwargs ):
3220 """
3321 Makes an arrow plot.
3422
3523 Parameters
3624 ----------
25+ ax
26+ The axes where the graph is drawn.
3727 data
3828 Dict with probabilities for the bases and pair transitions.
3929 size
40- Size of the graph in inches.
30+ Size of the plot, in inches.
4131 display : {'length', 'width', 'alpha'}
4232 The arrow property to change.
4333 shape : {'full', 'left', 'right'}
4434 For full or half arrows.
4535 max_arrow_width : float
46- Maximum width of an arrow, data coordinates.
36+ Maximum width of an arrow, in data coordinates.
4737 arrow_sep : float
48- Separation between arrows in a pair, data coordinates.
38+ Separation between arrows in a pair, in data coordinates.
4939 alpha : float
5040 Maximum opacity of arrows.
5141 **kwargs
52- Can be anything allowed by a Arrow object, e.g. *linewidth* or
53- *edgecolor*.
42+ `.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
5443 """
5544
56- plt .xlim (- 0.5 , 1.5 )
57- plt .ylim (- 0.5 , 1.5 )
58- plt .gcf ().set_size_inches (size , size )
59- plt .xticks ([])
60- plt .yticks ([])
45+ ax .set (xlim = (- 0.5 , 1.5 ), ylim = (- 0.5 , 1.5 ), xticks = [], yticks = [])
46+ ax .text (.01 , .01 , f'flux encoded as arrow { display } ' ,
47+ transform = ax .transAxes )
6148 max_text_size = size * 12
6249 min_text_size = size
6350 label_text_size = size * 2.5
64- text_params = {'ha' : 'center' , 'va' : 'center' , 'family' : 'sans-serif' ,
65- 'fontweight' : 'bold' }
66- r2 = np .sqrt (2 )
67-
68- deltas = {
69- 'AT' : (1 , 0 ),
70- 'TA' : (- 1 , 0 ),
71- 'GA' : (0 , 1 ),
72- 'AG' : (0 , - 1 ),
73- 'CA' : (- 1 / r2 , 1 / r2 ),
74- 'AC' : (1 / r2 , - 1 / r2 ),
75- 'GT' : (1 / r2 , 1 / r2 ),
76- 'TG' : (- 1 / r2 , - 1 / r2 ),
77- 'CT' : (0 , 1 ),
78- 'TC' : (0 , - 1 ),
79- 'GC' : (1 , 0 ),
80- 'CG' : (- 1 , 0 )}
8151
82- colors = {
83- 'AT' : 'r' ,
84- 'TA' : 'k' ,
85- 'GA' : 'g' ,
86- 'AG' : 'r' ,
87- 'CA' : 'b' ,
88- 'AC' : 'r' ,
89- 'GT' : 'g' ,
90- 'TG' : 'k' ,
91- 'CT' : 'b' ,
92- 'TC' : 'k' ,
93- 'GC' : 'g' ,
94- 'CG' : 'b' }
95-
96- label_positions = {
97- 'AT' : 'center' ,
98- 'TA' : 'center' ,
99- 'GA' : 'center' ,
100- 'AG' : 'center' ,
101- 'CA' : 'left' ,
102- 'AC' : 'left' ,
103- 'GT' : 'left' ,
104- 'TG' : 'left' ,
105- 'CT' : 'center' ,
106- 'TC' : 'center' ,
107- 'GC' : 'center' ,
108- 'CG' : 'center' }
109-
110- def do_fontsize (k ):
111- return float (np .clip (max_text_size * np .sqrt (data [k ]),
112- min_text_size , max_text_size ))
113-
114- plt .text (0 , 1 , '$A_3$' , color = 'r' , size = do_fontsize ('A' ), ** text_params )
115- plt .text (1 , 1 , '$T_3$' , color = 'k' , size = do_fontsize ('T' ), ** text_params )
116- plt .text (0 , 0 , '$G_3$' , color = 'g' , size = do_fontsize ('G' ), ** text_params )
117- plt .text (1 , 0 , '$C_3$' , color = 'b' , size = do_fontsize ('C' ), ** text_params )
52+ bases = 'ATGC'
53+ coords = {
54+ 'A' : np .array ([0 , 1 ]),
55+ 'T' : np .array ([1 , 1 ]),
56+ 'G' : np .array ([0 , 0 ]),
57+ 'C' : np .array ([1 , 0 ]),
58+ }
59+ colors = {'A' : 'r' , 'T' : 'k' , 'G' : 'g' , 'C' : 'b' }
60+
61+ for base in bases :
62+ fontsize = np .clip (max_text_size * data [base ]** (1 / 2 ),
63+ min_text_size , max_text_size )
64+ ax .text (* coords [base ], f'${ base } _3$' ,
65+ color = colors [base ], size = fontsize ,
66+ horizontalalignment = 'center' , verticalalignment = 'center' ,
67+ weight = 'bold' )
11868
11969 arrow_h_offset = 0.25 # data coordinates, empirically determined
12070 max_arrow_length = 1 - 2 * arrow_h_offset
12171 max_head_width = 2.5 * max_arrow_width
12272 max_head_length = 2 * max_arrow_width
123- arrow_params = {'length_includes_head' : True , 'shape' : shape ,
124- 'head_starts_at_zero' : head_starts_at_zero }
12573 sf = 0.6 # max arrow size represents this in data coords
12674
127- d = (r2 / 2 + arrow_h_offset - 0.5 ) / r2 # distance for diags
128- r2v = arrow_sep / r2 # offset for diags
129-
130- # tuple of x, y for start position
131- positions = {
132- 'AT' : (arrow_h_offset , 1 + arrow_sep ),
133- 'TA' : (1 - arrow_h_offset , 1 - arrow_sep ),
134- 'GA' : (- arrow_sep , arrow_h_offset ),
135- 'AG' : (arrow_sep , 1 - arrow_h_offset ),
136- 'CA' : (1 - d - r2v , d - r2v ),
137- 'AC' : (d + r2v , 1 - d + r2v ),
138- 'GT' : (d - r2v , d + r2v ),
139- 'TG' : (1 - d + r2v , 1 - d - r2v ),
140- 'CT' : (1 - arrow_sep , arrow_h_offset ),
141- 'TC' : (1 + arrow_sep , 1 - arrow_h_offset ),
142- 'GC' : (arrow_h_offset , arrow_sep ),
143- 'CG' : (1 - arrow_h_offset , - arrow_sep )}
144-
14575 if normalize_data :
14676 # find maximum value for rates, i.e. where keys are 2 chars long
14777 max_val = max ((v for k , v in data .items () if len (k ) == 2 ), default = 0 )
14878 # divide rates by max val, multiply by arrow scale factor
14979 for k , v in data .items ():
15080 data [k ] = v / max_val * sf
15181
152- def draw_arrow (pair , alpha = alpha , ec = ec , labelcolor = labelcolor ):
82+ # iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
83+ for pair in map ('' .join , itertools .permutations (bases , 2 )):
15384 # set the length of the arrow
15485 if display == 'length' :
15586 length = (max_head_length
@@ -159,7 +90,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
15990 # set the transparency of the arrow
16091 if display == 'alpha' :
16192 alpha = min (data [pair ] / sf , alpha )
162-
16393 # set the width of the arrow
16494 if display == 'width' :
16595 scale = data [pair ] / sf
@@ -171,137 +101,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
171101 head_width = max_head_width
172102 head_length = max_head_length
173103
174- fc = colors [pair ]
175- ec = ec or fc
176-
177- x_scale , y_scale = deltas [pair ]
178- x_pos , y_pos = positions [pair ]
179- plt .arrow (x_pos , y_pos , x_scale * length , y_scale * length ,
180- fc = fc , ec = ec , alpha = alpha , width = width ,
181- head_width = head_width , head_length = head_length ,
182- ** arrow_params )
183-
184- # figure out coordinates for text
104+ fc = colors [pair [0 ]]
105+
106+ cp0 = coords [pair [0 ]]
107+ cp1 = coords [pair [1 ]]
108+ # unit vector in arrow direction
109+ delta = cos , sin = (cp1 - cp0 ) / np .hypot (* (cp1 - cp0 ))
110+ x_pos , y_pos = (
111+ (cp0 + cp1 ) / 2 # midpoint
112+ - delta * length / 2 # half the arrow length
113+ + np .array ([- sin , cos ]) * arrow_sep # shift outwards by arrow_sep
114+ )
115+ ax .arrow (
116+ x_pos , y_pos , cos * length , sin * length ,
117+ fc = fc , ec = ec or fc , alpha = alpha , width = width ,
118+ head_width = head_width , head_length = head_length , shape = shape ,
119+ length_includes_head = True ,
120+ )
121+
122+ # figure out coordinates for text:
185123 # if drawing relative to base: x and y are same as for arrow
186124 # dx and dy are one arrow width left and up
187- # need to rotate based on direction of arrow, use x_scale and y_scale
188- # as sin x and cos x?
189- sx , cx = y_scale , x_scale
190-
191- where = label_positions [pair ]
192- if where == 'left' :
193- orig_position = 3 * np .array ([[max_arrow_width , max_arrow_width ]])
194- elif where == 'absolute' :
195- orig_position = np .array ([[max_arrow_length / 2.0 ,
196- 3 * max_arrow_width ]])
197- elif where == 'right' :
198- orig_position = np .array ([[length - 3 * max_arrow_width ,
199- 3 * max_arrow_width ]])
200- elif where == 'center' :
201- orig_position = np .array ([[length / 2.0 , 3 * max_arrow_width ]])
202- else :
203- raise ValueError ("Got unknown position parameter %s" % where )
204-
205- M = np .array ([[cx , sx ], [- sx , cx ]])
206- coords = np .dot (orig_position , M ) + [[x_pos , y_pos ]]
207- x , y = np .ravel (coords )
208- orig_label = rate_labels [pair ]
209- label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label [0 ], orig_label [1 :])
210-
211- plt .text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
212- color = labelcolor or fc )
213-
214- for p in sorted (positions ):
215- draw_arrow (p )
216-
217-
218- # test data
219- all_on_max = dict ([(i , 1 ) for i in 'TCAG' ] +
220- [(i + j , 0.6 ) for i in 'TCAG' for j in 'TCAG' ])
221-
222- realistic_data = {
223- 'A' : 0.4 ,
224- 'T' : 0.3 ,
225- 'G' : 0.5 ,
226- 'C' : 0.2 ,
227- 'AT' : 0.4 ,
228- 'AC' : 0.3 ,
229- 'AG' : 0.2 ,
230- 'TA' : 0.2 ,
231- 'TC' : 0.3 ,
232- 'TG' : 0.4 ,
233- 'CT' : 0.2 ,
234- 'CG' : 0.3 ,
235- 'CA' : 0.2 ,
236- 'GA' : 0.1 ,
237- 'GT' : 0.4 ,
238- 'GC' : 0.1 }
239-
240- extreme_data = {
241- 'A' : 0.75 ,
242- 'T' : 0.10 ,
243- 'G' : 0.10 ,
244- 'C' : 0.05 ,
245- 'AT' : 0.6 ,
246- 'AC' : 0.3 ,
247- 'AG' : 0.1 ,
248- 'TA' : 0.02 ,
249- 'TC' : 0.3 ,
250- 'TG' : 0.01 ,
251- 'CT' : 0.2 ,
252- 'CG' : 0.5 ,
253- 'CA' : 0.2 ,
254- 'GA' : 0.1 ,
255- 'GT' : 0.4 ,
256- 'GC' : 0.2 }
257-
258- sample_data = {
259- 'A' : 0.2137 ,
260- 'T' : 0.3541 ,
261- 'G' : 0.1946 ,
262- 'C' : 0.2376 ,
263- 'AT' : 0.0228 ,
264- 'AC' : 0.0684 ,
265- 'AG' : 0.2056 ,
266- 'TA' : 0.0315 ,
267- 'TC' : 0.0629 ,
268- 'TG' : 0.0315 ,
269- 'CT' : 0.1355 ,
270- 'CG' : 0.0401 ,
271- 'CA' : 0.0703 ,
272- 'GA' : 0.1824 ,
273- 'GT' : 0.0387 ,
274- 'GC' : 0.1106 }
125+ orig_positions = {
126+ 'base' : [3 * max_arrow_width , 3 * max_arrow_width ],
127+ 'center' : [length / 2 , 3 * max_arrow_width ],
128+ 'tip' : [length - 3 * max_arrow_width , 3 * max_arrow_width ],
129+ }
130+ # for diagonal arrows, put the label at the arrow base
131+ # for vertical or horizontal arrows, center the label
132+ where = 'base' if (cp0 != cp1 ).all () else 'center'
133+ # rotate based on direction of arrow (cos, sin)
134+ M = [[cos , - sin ], [sin , cos ]]
135+ x , y = np .dot (M , orig_positions [where ]) + [x_pos , y_pos ]
136+ label = r'$r_{_{\mathrm{%s}}}$' % (pair ,)
137+ ax .text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
138+ color = labelcolor or fc )
275139
276140
277141if __name__ == '__main__' :
278- from sys import argv
279- d = None
280- if len (argv ) > 1 :
281- if argv [1 ] == 'full' :
282- d = all_on_max
283- scaled = False
284- elif argv [1 ] == 'extreme' :
285- d = extreme_data
286- scaled = False
287- elif argv [1 ] == 'realistic' :
288- d = realistic_data
289- scaled = False
290- elif argv [1 ] == 'sample' :
291- d = sample_data
292- scaled = True
293- if d is None :
294- d = all_on_max
295- scaled = False
296- if len (argv ) > 2 :
297- display = argv [2 ]
298- else :
299- display = 'length'
142+ data = { # test data
143+ 'A' : 0.4 , 'T' : 0.3 , 'G' : 0.6 , 'C' : 0.2 ,
144+ 'AT' : 0.4 , 'AC' : 0.3 , 'AG' : 0.2 ,
145+ 'TA' : 0.2 , 'TC' : 0.3 , 'TG' : 0.4 ,
146+ 'CT' : 0.2 , 'CG' : 0.3 , 'CA' : 0.2 ,
147+ 'GA' : 0.1 , 'GT' : 0.4 , 'GC' : 0.1 ,
148+ }
300149
301150 size = 4
302- plt .figure (figsize = (size , size ))
151+ fig = plt .figure (figsize = (3 * size , size ), constrained_layout = True )
152+ axs = fig .subplot_mosaic ([["length" , "width" , "alpha" ]])
303153
304- make_arrow_plot (d , display = display , linewidth = 0.001 , edgecolor = None ,
305- normalize_data = scaled , head_starts_at_zero = True , size = size )
154+ for display , ax in axs .items ():
155+ make_arrow_graph (
156+ ax , data , display = display , linewidth = 0.001 , edgecolor = None ,
157+ normalize_data = True , size = size )
306158
307159 plt .show ()
0 commit comments