|
| 1 | +from pathlib import Path |
| 2 | +import plotly |
| 3 | +import dash |
| 4 | +from dash import dcc, html, Input, Output, callback, Patch, callback_context |
| 5 | +import plotly.graph_objects as go |
| 6 | +import dash_bootstrap_components as dbc |
| 7 | +from numpy import sqrt, linspace, pi, sign, exp, square, array, diff, matmul, zeros, meshgrid |
| 8 | +from numpy.random import randint |
| 9 | +from numpy.linalg import det, solve |
| 10 | + |
| 11 | +from components.popup_box import PopupBox |
| 12 | +from components.label import Label |
| 13 | + |
| 14 | + |
| 15 | +from model.selfcontained_distribution import SelfContainedDistribution |
| 16 | + |
| 17 | +class Conditional(SelfContainedDistribution): |
| 18 | + def __init__(self): |
| 19 | + # Colors |
| 20 | + col_marginal = plotly.colors.qualitative.Plotly[0] |
| 21 | + col_conditional = plotly.colors.qualitative.Plotly[1] |
| 22 | + col_slice = plotly.colors.qualitative.Plotly[4] |
| 23 | + # axis limits |
| 24 | + rangx = [-4, 4] |
| 25 | + rangy = [-4, 4] |
| 26 | + # slider range |
| 27 | + smin = rangy[0] |
| 28 | + smax = rangy[1] |
| 29 | + # plot size relative to window size |
| 30 | + relwidth = 95 |
| 31 | + relheight = round((relwidth/diff(rangx)*diff(rangy))[0]) |
| 32 | + |
| 33 | + # Grid |
| 34 | + self.xv = linspace(rangx[0], rangx[1], 100) |
| 35 | + self.yv = linspace(rangy[0], rangy[1], 100) |
| 36 | + self.xm, self.ym = meshgrid(self.xv, self.yv) |
| 37 | + |
| 38 | + self.config = { |
| 39 | + 'toImageButtonOptions': { |
| 40 | + 'format': 'jpeg', # png, svg, pdf, jpeg, webp |
| 41 | + 'width': None, # None: use currently-rendered size |
| 42 | + 'height': None, |
| 43 | + 'filename': 'conditional', |
| 44 | + }, |
| 45 | + 'responsive': True, |
| 46 | + 'scrollZoom': True, |
| 47 | + } |
| 48 | + |
| 49 | + self.fig = go.Figure( |
| 50 | + data=[ |
| 51 | + go.Surface( |
| 52 | + name='Joint f(x,y)', |
| 53 | + x=self.xm, |
| 54 | + y=self.ym, |
| 55 | + z=self.xm*0, |
| 56 | + hoverinfo='skip', |
| 57 | + colorscale='Cividis', |
| 58 | + showlegend=True, |
| 59 | + showscale=False, |
| 60 | + reversescale=False, |
| 61 | + contours={ |
| 62 | + "z": { |
| 63 | + "show": True, |
| 64 | + "start": 0.1, |
| 65 | + "end": 1, |
| 66 | + "size": 0.1, |
| 67 | + "width": 1, |
| 68 | + "color": "white" |
| 69 | + } |
| 70 | + } |
| 71 | + ), |
| 72 | + go.Scatter3d( |
| 73 | + name='Marginal f(x)', |
| 74 | + x=self.xv, |
| 75 | + y=self.yv*0+self.yv[0], |
| 76 | + mode='lines', |
| 77 | + z=self.xv*0, |
| 78 | + marker_color=col_marginal, |
| 79 | + showlegend=True, |
| 80 | + hoverinfo='skip', |
| 81 | + line={'width': 10}, |
| 82 | + # surfaceaxis=2 # wait for: https://github.com/plotly/plotly.js/issues/2352 |
| 83 | + ), |
| 84 | + go.Scatter3d( |
| 85 | + name='Conditional f(x|ŷ)', |
| 86 | + x=self.xv, y=self.yv*0+self.yv[0], |
| 87 | + mode='lines', |
| 88 | + z=self.xv*0, |
| 89 | + marker_color=col_conditional, |
| 90 | + showlegend=True, |
| 91 | + hoverinfo='skip', |
| 92 | + line={'width': 4}, |
| 93 | + # surfaceaxis=2 # wait for: https://github.com/plotly/plotly.js/issues/2352 |
| 94 | + ), |
| 95 | + go.Scatter3d( |
| 96 | + name='Slice f(x,ŷ)', |
| 97 | + x=self.xv, |
| 98 | + y=self.yv*0+self.yv[0], |
| 99 | + mode='lines', |
| 100 | + z=self.xv*0, |
| 101 | + marker_color=col_slice, |
| 102 | + showlegend=True, |
| 103 | + hoverinfo='skip', |
| 104 | + line={'width': 8} |
| 105 | + ), |
| 106 | + ] |
| 107 | + ) |
| 108 | + self.fig.update_xaxes(range=rangx, tickmode='array', tickvals=list(range(rangx[0], rangx[1]+1))) |
| 109 | + self.fig.update_yaxes(range=rangy, tickmode='array', tickvals=list(range(rangy[0], rangy[1]+1)), scaleanchor="x", scaleratio=1) |
| 110 | + self.fig.update_layout(transition_duration=100, transition_easing='linear') |
| 111 | + self.fig.update_scenes(camera_projection_type="orthographic") |
| 112 | + self.fig.update_scenes(aspectmode="auto") |
| 113 | + # fig.update_scenes(xaxis_nticks=1) |
| 114 | + # fig.update_scenes(yaxis_nticks=1) |
| 115 | + self.fig.update_scenes(zaxis_nticks=1) |
| 116 | + self.fig.update_layout(margin=dict(l=0, r=0, t=0, b=0, pad=0)) |
| 117 | + |
| 118 | + self.fig.update_layout( |
| 119 | + legend=dict( |
| 120 | + yanchor="top", |
| 121 | + y=0.98, |
| 122 | + xanchor="left", |
| 123 | + x=0.02, |
| 124 | + bgcolor="rgba(255,255,255,0.7)" |
| 125 | + ) |
| 126 | + ) |
| 127 | + |
| 128 | + path = Path(__file__).parent / "info_text.md" |
| 129 | + with open(path, 'r') as f: |
| 130 | + self.info_text = dcc.Markdown(f.read(), mathjax=True) |
| 131 | + |
| 132 | + self.settings_layout = [ |
| 133 | + dbc.Container( |
| 134 | + dbc.Col([ |
| 135 | + html.Br(), |
| 136 | + |
| 137 | + # y Slider |
| 138 | + Label("Condition on ŷ", |
| 139 | + dcc.Slider( |
| 140 | + id="joint-y", |
| 141 | + min=smin, |
| 142 | + max=smax, |
| 143 | + value=randint(smin*10, smax*10)/10, |
| 144 | + updatemode='drag', marks=None, |
| 145 | + tooltip={"template": "ŷ={value}", "placement": "bottom", "always_visible": True} |
| 146 | + ), |
| 147 | + ), |
| 148 | + |
| 149 | + # ρ Slider |
| 150 | + Label("Correlation ρ", |
| 151 | + dcc.Slider( |
| 152 | + id="joint-ρ", |
| 153 | + min=-1, max=1, |
| 154 | + value=randint(-9, 9)/10, |
| 155 | + updatemode='mouseup', |
| 156 | + marks=None, |
| 157 | + tooltip={"template": "ρ={value}", "placement": "bottom", "always_visible": True} |
| 158 | + ) |
| 159 | + ), |
| 160 | + |
| 161 | + html.Hr(), |
| 162 | + html.Br(), |
| 163 | + |
| 164 | + # Info Popup |
| 165 | + *PopupBox("joint-info", "Learn More", "Additional Information", self.info_text), |
| 166 | + |
| 167 | + ]), |
| 168 | + fluid=True, |
| 169 | + className="g-0" |
| 170 | + ), |
| 171 | + ] |
| 172 | + self.plot_layout = [ |
| 173 | + dcc.Graph(id="joint-graph", figure=self.fig, config=self.config, style={'height': '100%'}), |
| 174 | + ] |
| 175 | + |
| 176 | + self._register_callbacks() |
| 177 | + |
| 178 | + def _register_callbacks(self): |
| 179 | + @callback( |
| 180 | + Output("joint-graph", "figure"), |
| 181 | + Input("joint-y", "value"), |
| 182 | + Input("joint-ρ", "value"), |
| 183 | + ) |
| 184 | + def update(ys, ρ): |
| 185 | + patched_fig = Patch() |
| 186 | + # Joint Parameters |
| 187 | + μ = zeros([2, 1]) |
| 188 | + sx = 1 |
| 189 | + sy = 1 |
| 190 | + # TODO special treatment for singular density |
| 191 | + ρ = sign(ρ) * min(abs(ρ), .9999) |
| 192 | + C = array([[sx**2, sx*sy*ρ], [sx*sy*ρ, sy**2]]) |
| 193 | + # Marginal Parameters |
| 194 | + µMarginal = µ[0] |
| 195 | + CMarginal = C[0, 0] |
| 196 | + marginal_fac = 1 / self.gauss1(0, 0, CMarginal) |
| 197 | + # Density has been modified? |
| 198 | + if (callback_context.triggered_id == "joint-ρ") | (callback_context.triggered_id is None): |
| 199 | + zMarginal = self.gauss1(self.xv, µMarginal, CMarginal) |
| 200 | + patched_fig['data'][1]['z'] = zMarginal * marginal_fac |
| 201 | + # Compute new joint density values |
| 202 | + # TODO should be more elegant than 2 for loops |
| 203 | + zJoint = self.xm*0 |
| 204 | + for i in range(self.xm.shape[0]): |
| 205 | + for j in range(self.xm.shape[1]): |
| 206 | + zJoint[i, j] = self.gauss2(self.xm[i, j], self.ym[i, j], μ, C) |
| 207 | + zJoint = zJoint / self.gauss2(0, 0, zeros([2, 1]), C) # rescale to height 1 |
| 208 | + patched_fig['data'][0]['z'] = zJoint |
| 209 | + # Compute Conditional |
| 210 | + # https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf |
| 211 | + µCond = µ[0] + C[0, 1] / C[1, 1] * (ys-µ[1]) |
| 212 | + CCond = C[0, 0] - C[0, 1] / C[1, 1] * C[0, 1] |
| 213 | + zCond = self.gauss1(self.xv, µCond, CCond) |
| 214 | + zSlice = zCond / self.gauss1(0, 0, CCond) * self.gauss1(ys, µ[1], C[1, 1]) / self.gauss1(0, 0, C[1, 1]) |
| 215 | + # Plot Conditional |
| 216 | + patched_fig['data'][2]['z'] = zCond * marginal_fac |
| 217 | + # Plot Joint Slice |
| 218 | + patched_fig['data'][3]['y'] = self.xv*0+ys |
| 219 | + patched_fig['data'][3]['z'] = zSlice + 1e-3 |
| 220 | + return patched_fig |
| 221 | + |
| 222 | + @staticmethod |
| 223 | + def gauss1(x, μ, C): |
| 224 | + return 1/sqrt(2*pi*C) * exp(-1/2 * square((x-μ))/C) |
| 225 | + |
| 226 | + |
| 227 | + @staticmethod |
| 228 | + def gauss2(x, y, μ, C): |
| 229 | + d = array([x-μ[0], y-μ[1]]) |
| 230 | + d = d.reshape(-1, 1) # to column vector |
| 231 | + f = 1/sqrt(det(2*pi*C)) * exp(-1/2 * matmul(d.T, solve(C, d))) |
| 232 | + return f[0][0] |
| 233 | + |
| 234 | + |
| 235 | + |
0 commit comments