2
2
import numpy as np
3
3
from scipy .stats import kde
4
4
5
- __all__ = ['traceplot' ]
5
+ __all__ = ['traceplot' , 'kdeplot' , 'kde2plot' ]
6
6
7
7
def traceplot (trace , vars = None ):
8
8
if vars is None :
@@ -14,7 +14,7 @@ def traceplot(trace, vars=None):
14
14
for i ,v in enumerate (vars ):
15
15
d = np .squeeze (trace [v ])
16
16
17
- kdeplot (ax [0 ,i ], d )
17
+ kdeplot_op (ax [0 ,i ], d )
18
18
ax [0 ,i ].set_title (str (v ))
19
19
ax [1 ,i ].plot (d , alpha = .35 )
20
20
@@ -23,7 +23,7 @@ def traceplot(trace, vars=None):
23
23
24
24
return f
25
25
26
- def kdeplot (ax , data ):
26
+ def kdeplot_op (ax , data ):
27
27
data = np .atleast_2d (data .T ).T
28
28
for i in range (data .shape [1 ]):
29
29
d = data [:,i ]
@@ -33,6 +33,32 @@ def kdeplot(ax, data):
33
33
x = np .linspace (0 ,1 ,100 )* (u - l )+ l
34
34
35
35
ax .plot (x ,density (x ))
36
+
37
+ def kde2plot_op (ax , x , y , grid = 200 ):
38
+ xmin = x .min ()
39
+ xmax = x .max ()
40
+ ymin = y .min ()
41
+ ymax = y .max ()
42
+
43
+ grid = grid * 1j
44
+ X , Y = np .mgrid [xmin :xmax :grid , ymin :ymax :grid ]
45
+ positions = np .vstack ([X .ravel (), Y .ravel ()])
46
+ values = np .vstack ([x , y ])
47
+ kernel = kde .gaussian_kde (values )
48
+ Z = np .reshape (kernel (positions ).T , X .shape )
49
+
50
+ ax .imshow (np .rot90 (Z ), cmap = p .cm .gist_earth_r ,
51
+ extent = [xmin , xmax , ymin , ymax ])
52
+
53
+ def kdeplot (data ):
54
+ f , ax = p .subplots (1 , 1 , squeeze = True )
55
+ kdeplot_op (ax , data )
56
+ return f
57
+
58
+ def kde2plot (x ,y , grid = 200 ):
59
+ f , ax = p .subplots (1 , 1 , squeeze = True )
60
+ kde2plot_op (ax , x ,y , grid )
61
+ return f
36
62
37
63
38
64
0 commit comments