Skip to content

Commit b5962f3

Browse files
committed
refractored conditional
1 parent ef602fa commit b5962f3

File tree

4 files changed

+280
-204
lines changed

4 files changed

+280
-204
lines changed

model/distributions/conditional/__init__.py

Whitespace-only changes.
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
from pathlib import Path
2+
import plotly
3+
import dash
4+
from dash import dcc, html, Input, Output, callback, Patch, callback_context
5+
import plotly.graph_objects as go
6+
import dash_bootstrap_components as dbc
7+
from numpy import sqrt, linspace, pi, sign, exp, square, array, diff, matmul, zeros, meshgrid
8+
from numpy.random import randint
9+
from numpy.linalg import det, solve
10+
11+
from components.popup_box import PopupBox
12+
from components.label import Label
13+
14+
15+
from model.selfcontained_distribution import SelfContainedDistribution
16+
17+
class Conditional(SelfContainedDistribution):
18+
def __init__(self):
19+
# Colors
20+
col_marginal = plotly.colors.qualitative.Plotly[0]
21+
col_conditional = plotly.colors.qualitative.Plotly[1]
22+
col_slice = plotly.colors.qualitative.Plotly[4]
23+
# axis limits
24+
rangx = [-4, 4]
25+
rangy = [-4, 4]
26+
# slider range
27+
smin = rangy[0]
28+
smax = rangy[1]
29+
# plot size relative to window size
30+
relwidth = 95
31+
relheight = round((relwidth/diff(rangx)*diff(rangy))[0])
32+
33+
# Grid
34+
self.xv = linspace(rangx[0], rangx[1], 100)
35+
self.yv = linspace(rangy[0], rangy[1], 100)
36+
self.xm, self.ym = meshgrid(self.xv, self.yv)
37+
38+
self.config = {
39+
'toImageButtonOptions': {
40+
'format': 'jpeg', # png, svg, pdf, jpeg, webp
41+
'width': None, # None: use currently-rendered size
42+
'height': None,
43+
'filename': 'conditional',
44+
},
45+
'responsive': True,
46+
'scrollZoom': True,
47+
}
48+
49+
self.fig = go.Figure(
50+
data=[
51+
go.Surface(
52+
name='Joint f(x,y)',
53+
x=self.xm,
54+
y=self.ym,
55+
z=self.xm*0,
56+
hoverinfo='skip',
57+
colorscale='Cividis',
58+
showlegend=True,
59+
showscale=False,
60+
reversescale=False,
61+
contours={
62+
"z": {
63+
"show": True,
64+
"start": 0.1,
65+
"end": 1,
66+
"size": 0.1,
67+
"width": 1,
68+
"color": "white"
69+
}
70+
}
71+
),
72+
go.Scatter3d(
73+
name='Marginal f(x)',
74+
x=self.xv,
75+
y=self.yv*0+self.yv[0],
76+
mode='lines',
77+
z=self.xv*0,
78+
marker_color=col_marginal,
79+
showlegend=True,
80+
hoverinfo='skip',
81+
line={'width': 10},
82+
# surfaceaxis=2 # wait for: https://github.com/plotly/plotly.js/issues/2352
83+
),
84+
go.Scatter3d(
85+
name='Conditional f(x|ŷ)',
86+
x=self.xv, y=self.yv*0+self.yv[0],
87+
mode='lines',
88+
z=self.xv*0,
89+
marker_color=col_conditional,
90+
showlegend=True,
91+
hoverinfo='skip',
92+
line={'width': 4},
93+
# surfaceaxis=2 # wait for: https://github.com/plotly/plotly.js/issues/2352
94+
),
95+
go.Scatter3d(
96+
name='Slice f(x,ŷ)',
97+
x=self.xv,
98+
y=self.yv*0+self.yv[0],
99+
mode='lines',
100+
z=self.xv*0,
101+
marker_color=col_slice,
102+
showlegend=True,
103+
hoverinfo='skip',
104+
line={'width': 8}
105+
),
106+
]
107+
)
108+
self.fig.update_xaxes(range=rangx, tickmode='array', tickvals=list(range(rangx[0], rangx[1]+1)))
109+
self.fig.update_yaxes(range=rangy, tickmode='array', tickvals=list(range(rangy[0], rangy[1]+1)), scaleanchor="x", scaleratio=1)
110+
self.fig.update_layout(transition_duration=100, transition_easing='linear')
111+
self.fig.update_scenes(camera_projection_type="orthographic")
112+
self.fig.update_scenes(aspectmode="auto")
113+
# fig.update_scenes(xaxis_nticks=1)
114+
# fig.update_scenes(yaxis_nticks=1)
115+
self.fig.update_scenes(zaxis_nticks=1)
116+
self.fig.update_layout(margin=dict(l=0, r=0, t=0, b=0, pad=0))
117+
118+
self.fig.update_layout(
119+
legend=dict(
120+
yanchor="top",
121+
y=0.98,
122+
xanchor="left",
123+
x=0.02,
124+
bgcolor="rgba(255,255,255,0.7)"
125+
)
126+
)
127+
128+
path = Path(__file__).parent / "info_text.md"
129+
with open(path, 'r') as f:
130+
self.info_text = dcc.Markdown(f.read(), mathjax=True)
131+
132+
self.settings_layout = [
133+
dbc.Container(
134+
dbc.Col([
135+
html.Br(),
136+
137+
# y Slider
138+
Label("Condition on ŷ",
139+
dcc.Slider(
140+
id="joint-y",
141+
min=smin,
142+
max=smax,
143+
value=randint(smin*10, smax*10)/10,
144+
updatemode='drag', marks=None,
145+
tooltip={"template": "ŷ={value}", "placement": "bottom", "always_visible": True}
146+
),
147+
),
148+
149+
# ρ Slider
150+
Label("Correlation ρ",
151+
dcc.Slider(
152+
id="joint-ρ",
153+
min=-1, max=1,
154+
value=randint(-9, 9)/10,
155+
updatemode='mouseup',
156+
marks=None,
157+
tooltip={"template": "ρ={value}", "placement": "bottom", "always_visible": True}
158+
)
159+
),
160+
161+
html.Hr(),
162+
html.Br(),
163+
164+
# Info Popup
165+
*PopupBox("joint-info", "Learn More", "Additional Information", self.info_text),
166+
167+
]),
168+
fluid=True,
169+
className="g-0"
170+
),
171+
]
172+
self.plot_layout = [
173+
dcc.Graph(id="joint-graph", figure=self.fig, config=self.config, style={'height': '100%'}),
174+
]
175+
176+
self._register_callbacks()
177+
178+
def _register_callbacks(self):
179+
@callback(
180+
Output("joint-graph", "figure"),
181+
Input("joint-y", "value"),
182+
Input("joint-ρ", "value"),
183+
)
184+
def update(ys, ρ):
185+
patched_fig = Patch()
186+
# Joint Parameters
187+
μ = zeros([2, 1])
188+
sx = 1
189+
sy = 1
190+
# TODO special treatment for singular density
191+
ρ = sign(ρ) * min(abs(ρ), .9999)
192+
C = array([[sx**2, sx*sy*ρ], [sx*sy*ρ, sy**2]])
193+
# Marginal Parameters
194+
µMarginal = µ[0]
195+
CMarginal = C[0, 0]
196+
marginal_fac = 1 / self.gauss1(0, 0, CMarginal)
197+
# Density has been modified?
198+
if (callback_context.triggered_id == "joint-ρ") | (callback_context.triggered_id is None):
199+
zMarginal = self.gauss1(self.xv, µMarginal, CMarginal)
200+
patched_fig['data'][1]['z'] = zMarginal * marginal_fac
201+
# Compute new joint density values
202+
# TODO should be more elegant than 2 for loops
203+
zJoint = self.xm*0
204+
for i in range(self.xm.shape[0]):
205+
for j in range(self.xm.shape[1]):
206+
zJoint[i, j] = self.gauss2(self.xm[i, j], self.ym[i, j], μ, C)
207+
zJoint = zJoint / self.gauss2(0, 0, zeros([2, 1]), C) # rescale to height 1
208+
patched_fig['data'][0]['z'] = zJoint
209+
# Compute Conditional
210+
# https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
211+
µCond = µ[0] + C[0, 1] / C[1, 1] * (ys-µ[1])
212+
CCond = C[0, 0] - C[0, 1] / C[1, 1] * C[0, 1]
213+
zCond = self.gauss1(self.xv, µCond, CCond)
214+
zSlice = zCond / self.gauss1(0, 0, CCond) * self.gauss1(ys, µ[1], C[1, 1]) / self.gauss1(0, 0, C[1, 1])
215+
# Plot Conditional
216+
patched_fig['data'][2]['z'] = zCond * marginal_fac
217+
# Plot Joint Slice
218+
patched_fig['data'][3]['y'] = self.xv*0+ys
219+
patched_fig['data'][3]['z'] = zSlice + 1e-3
220+
return patched_fig
221+
222+
@staticmethod
223+
def gauss1(x, μ, C):
224+
return 1/sqrt(2*pi*C) * exp(-1/2 * square((x-μ))/C)
225+
226+
227+
@staticmethod
228+
def gauss2(x, y, μ, C):
229+
d = array([x-μ[0], y-μ[1]])
230+
d = d.reshape(-1, 1) # to column vector
231+
f = 1/sqrt(det(2*pi*C)) * exp(-1/2 * matmul(d.T, solve(C, d)))
232+
return f[0][0]
233+
234+
235+
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
## 2D Gaussian
2+
Interactive visualizaton of the 2D Gaussian and its marginal and conditional density.
3+
4+
$$
5+
f(\underline x) = \mathcal{N}(\underline x; \underline \mu, \textbf{C}) =
6+
\frac{1}{2\pi \sqrt{\det(\textbf{C})}}
7+
\cdot \exp\!\left\{ -\frac{1}{2}
8+
\cdot (\underline x - \underline \mu)^\top \textbf{C}^{-1} (\underline x - \underline \mu) \right\} \enspace,
9+
\quad \underline{x}\in \mathbb{R}^2 \enspace, \quad \textbf{C} \enspace \text{positive semidefinite} \enspace.
10+
$$
11+
12+
### Formulas and Literature
13+
The Gaussian parameters are restricted to
14+
$$
15+
\underline \mu = \begin{bmatrix}0 \\ 0\end{bmatrix}\,, \quad
16+
\textbf{C} = \begin{bmatrix}1 & \rho \\ \rho & 1\end{bmatrix} \enspace.
17+
$$
18+
19+
Formulas for marginalization and conditioning of are given in the
20+
[[MatrixCookbook](https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf)].
21+
22+
Note that the 1D and 2D densities are scaled with respect to each other such that 2D joint and 1D marginal have
23+
the same height and therefore the same shape when looking on the x-z plane.
24+
25+
26+
### Interactivity
27+
- GUI
28+
- rotate: left mouse click
29+
- pan: right mouse click
30+
- zoom: mouse wheel
31+
- add/remove lines: click in legend
32+
- value in state space (slider)
33+
- value to condition on $\hat{y}$
34+
- density parameter (slider)
35+
- correlation coefficient $\rho$

0 commit comments

Comments
 (0)