|
| 1 | +from pathlib import Path |
| 2 | +import plotly |
| 3 | +import dash |
| 4 | +import pandas |
| 5 | +from urllib.error import HTTPError |
| 6 | +from dash import dcc, html, Input, Output, callback, Patch |
| 7 | +import plotly.graph_objects as go |
| 8 | +import dash_bootstrap_components as dbc |
| 9 | +from numpy import sqrt, linspace, vstack, hstack, pi, nan, full, exp, square, arange, array, sin, cos, diff, matmul, log10, deg2rad, identity, ones, zeros, diag, cov, mean |
| 10 | +from numpy.random import randn, randint |
| 11 | +from numpy.linalg import cholesky, eig, det, inv |
| 12 | +from scipy.special import erfinv |
| 13 | + |
| 14 | +from components.popup_box import PopupBox |
| 15 | +from model.selfcontained_distribution import SelfContainedDistribution |
| 16 | + |
| 17 | +# Samples Library |
| 18 | +# technically, this has state, but its fine because its just a cache |
| 19 | +data_dict = {} |
| 20 | + |
| 21 | +# Get Samples from Library (and load if not available) |
| 22 | +def get_data(url): |
| 23 | + if not (url in data_dict): |
| 24 | + try: |
| 25 | + data = pandas.read_csv(url, header=None).to_numpy() |
| 26 | + except HTTPError: |
| 27 | + # URL doesn't exist |
| 28 | + data = full([2, 0], nan) |
| 29 | + data_dict[url] = data |
| 30 | + return data_dict[url] |
| 31 | + |
| 32 | +class Gaus2D(SelfContainedDistribution): |
| 33 | + def __init__(self): |
| 34 | + self.smethods = ['iid', 'Fibonacci', 'LCD', 'SP-Julier04', 'SP-Menegaz11'] # Sampling methods |
| 35 | + self.tmethods = ['Cholesky', 'Eigendecomposition'] # Transformation methods |
| 36 | + |
| 37 | + # Colors |
| 38 | + self.col_density = plotly.colors.qualitative.Plotly[1] |
| 39 | + self.col_samples = plotly.colors.qualitative.Plotly[0] |
| 40 | + |
| 41 | + self.config = { |
| 42 | + 'toImageButtonOptions': { |
| 43 | + 'format': 'jpeg', # png, svg, pdf, jpeg, webp |
| 44 | + 'height': None, # None: use currently-rendered size |
| 45 | + 'width': None, |
| 46 | + 'filename': 'gauss2d', |
| 47 | + }, |
| 48 | + 'scrollZoom': True, |
| 49 | + } |
| 50 | + |
| 51 | + # axis limits |
| 52 | + self.rangx = [-5, 5] |
| 53 | + self.rangy = [-4, 4] |
| 54 | + |
| 55 | + # plot size relative to window size |
| 56 | + relwidth = 95 |
| 57 | + self.relheight = round((relwidth/diff(self.rangx)*diff(self.rangy))[0]) |
| 58 | + |
| 59 | + # Gauss ellipse |
| 60 | + s = linspace(0, 2*pi, 500) |
| 61 | + self.circ = vstack((cos(s), sin(s))) * 2 |
| 62 | + |
| 63 | + self.fig = go.Figure( |
| 64 | + data=[ |
| 65 | + go.Scattergl(name='Density', |
| 66 | + x=[0], |
| 67 | + y=[0], |
| 68 | + mode='lines', |
| 69 | + marker_color=self.col_density, |
| 70 | + showlegend=True, |
| 71 | + hoverinfo='skip', |
| 72 | + line={'width': 3}, |
| 73 | + line_shape='linear', |
| 74 | + fill='tozerox' |
| 75 | + ), |
| 76 | + go.Scattergl( |
| 77 | + name='Samples', |
| 78 | + x=[0], |
| 79 | + y=[0], |
| 80 | + mode='markers', |
| 81 | + marker_color=self.col_samples, |
| 82 | + marker_line_color='black', |
| 83 | + marker_opacity=1, |
| 84 | + showlegend=True |
| 85 | + ) |
| 86 | + ], |
| 87 | + ) |
| 88 | + self.fig.update_xaxes(range=self.rangx, tickmode='array', tickvals=list(range(self.rangx[0], self.rangx[1]+1))) |
| 89 | + self.fig.update_yaxes(range=self.rangy, tickmode='array', tickvals=list(range(self.rangy[0], self.rangy[1]+1)), scaleanchor="x", scaleratio=1) |
| 90 | + self.fig.update_layout(legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1)) |
| 91 | + self.fig.update_layout(modebar_add=['drawopenpath', 'eraseshape'], newshape_line_color='cyan', dragmode='pan') |
| 92 | + |
| 93 | + self.fig.update_layout( |
| 94 | + legend=dict( |
| 95 | + orientation="v", |
| 96 | + xanchor="right", |
| 97 | + x=0.1, |
| 98 | + ) |
| 99 | + ) |
| 100 | + |
| 101 | + path = Path(__file__).parent / "info_text.md" |
| 102 | + with open(path, 'r') as f: |
| 103 | + self.info_text = dcc.Markdown(f.read(), mathjax=True) |
| 104 | + |
| 105 | + self.settings_layout = [ |
| 106 | + dbc.Container( |
| 107 | + dbc.Col([ |
| 108 | + html.P("Select Sampling Method:"), |
| 109 | + html.Br(), |
| 110 | + |
| 111 | + # Sampling Strategy RadioItems |
| 112 | + dbc.RadioItems(id='gauss2D-smethod', |
| 113 | + options=[{"label": x, "value": x} for x in self.smethods], |
| 114 | + value=self.smethods[randint(len(self.smethods))], |
| 115 | + inline=True), |
| 116 | + |
| 117 | + # Transformation Method RadioItems |
| 118 | + dbc.RadioItems(id='gauss2D-tmethod', |
| 119 | + options=[{"label": x, "value": x} for x in self.tmethods], |
| 120 | + value=self.tmethods[randint(len(self.tmethods))], |
| 121 | + inline=True), |
| 122 | + |
| 123 | + html.Br(), |
| 124 | + html.Hr(), |
| 125 | + html.Br(), |
| 126 | + |
| 127 | + # param Slider |
| 128 | + dcc.Slider(id="gauss2D-p", min=0, max=1, value=randint(3, 7)/10, updatemode='drag', marks=None, |
| 129 | + tooltip={"template": "p={value}", "placement": "bottom", "always_visible": True}), |
| 130 | + |
| 131 | + # L Slider |
| 132 | + dcc.Slider(id="gauss2D-L", min=log10(1.2), max=4.001, step=0.001, value=2, updatemode='drag', marks=None, # persistence=True, |
| 133 | + tooltip={"template": "L={value}", "placement": "bottom", "always_visible": True, "transform": "trafo_L"}), |
| 134 | + |
| 135 | + # σ Slider |
| 136 | + dcc.Slider(id="gauss2D-σx", min=0, max=5, step=0.01, value=1, updatemode='drag', marks=None, |
| 137 | + tooltip={"template": 'σx={value}', "placement": "bottom", "always_visible": True}), |
| 138 | + dcc.Slider(id="gauss2D-σy", min=0, max=5, step=0.01, value=1, updatemode='drag', marks=None, |
| 139 | + tooltip={"template": 'σy={value}', "placement": "bottom", "always_visible": True}), |
| 140 | + |
| 141 | + # ρ Slider |
| 142 | + dcc.Slider(id="gauss2D-ρ", min=-1, max=1, step=0.001, value=0, updatemode='drag', marks=None, |
| 143 | + tooltip={"template": 'ρ={value}', "placement": "bottom", "always_visible": True}), |
| 144 | + |
| 145 | + html.Hr(), |
| 146 | + html.Br(), |
| 147 | + |
| 148 | + |
| 149 | + # Info Popup |
| 150 | + *PopupBox("gauss2D-info", "Learn More", "Additional Information", self.info_text), |
| 151 | + ]), |
| 152 | + fluid=True, |
| 153 | + className="g-0") |
| 154 | + ] |
| 155 | + |
| 156 | + self.plot_layout = [ |
| 157 | + dcc.Graph(id="gauss2D-graph", figure=self.fig, config=self.config, style={'height': '100%'}), |
| 158 | + ] |
| 159 | + |
| 160 | + self._register_callbacks() |
| 161 | + |
| 162 | + def _register_callbacks(self): |
| 163 | + @callback( |
| 164 | + Output('gauss2D-p', 'min'), |
| 165 | + Output('gauss2D-p', 'max'), |
| 166 | + Output('gauss2D-p', 'value'), |
| 167 | + Output('gauss2D-p', 'step'), |
| 168 | + Output('gauss2D-p', 'tooltip'), |
| 169 | + Output('gauss2D-L', 'disabled'), |
| 170 | + Input("gauss2D-smethod", "value"), |
| 171 | + ) |
| 172 | + def update_smethod(smethod): |
| 173 | + patched_tooltip = Patch() |
| 174 | + match smethod: |
| 175 | + case 'iid': |
| 176 | + patched_tooltip.template = "dice" |
| 177 | + # min, max, value, step, tooltip |
| 178 | + return 0, 1, .5, 0.001, patched_tooltip, False |
| 179 | + case 'Fibonacci': |
| 180 | + patched_tooltip.template = "z={value}" |
| 181 | + return -50, 50, 0, 1, patched_tooltip, False |
| 182 | + case 'LCD': |
| 183 | + patched_tooltip.template = "α={value}°" |
| 184 | + return -360, 360, 0, 0.1, patched_tooltip, False |
| 185 | + case 'SP-Julier04': |
| 186 | + patched_tooltip.template = "W₀={value}" |
| 187 | + return -2, 1, .1, 0.001, patched_tooltip, True |
| 188 | + case 'SP-Menegaz11': |
| 189 | + patched_tooltip.template = "Wₙ₊₁={value}" |
| 190 | + return 0, 1, 1/3, 0.001, patched_tooltip, True |
| 191 | + case _: |
| 192 | + raise Exception("Wrong smethod") |
| 193 | + |
| 194 | + |
| 195 | + @callback( |
| 196 | + Output("gauss2D-graph", "figure"), |
| 197 | + Input("gauss2D-smethod", "value"), |
| 198 | + Input("gauss2D-tmethod", "value"), |
| 199 | + Input("gauss2D-p", "value"), |
| 200 | + Input("gauss2D-L", "value"), |
| 201 | + Input("gauss2D-σx", "value"), |
| 202 | + Input("gauss2D-σy", "value"), |
| 203 | + Input("gauss2D-ρ", "value"), |
| 204 | + ) |
| 205 | + def update(smethod, tmethod, p, L0, σx, σy, ρ): |
| 206 | + # Slider Transform, |
| 207 | + L = self.trafo_L(L0) |
| 208 | + # Mean |
| 209 | + # μ = array([[μx], [μy]]) |
| 210 | + μ = array([[0], [0]]) |
| 211 | + # Covariance |
| 212 | + C = array([[square(σx), σx*σy*ρ], [σx*σy*ρ, square(σy)]]) |
| 213 | + C_D, C_R = eig(C) |
| 214 | + C_D = C_D[..., None] # to column vector |
| 215 | + |
| 216 | + patched_fig = Patch() |
| 217 | + # Draw SND |
| 218 | + weights = None |
| 219 | + match smethod: |
| 220 | + case 'iid': |
| 221 | + xySND = randn(2, L) |
| 222 | + case 'Fibonacci': |
| 223 | + # TODO 2nd parameter |
| 224 | + xUni = (sqrt(5)-1)/2 * (arange(L)+1+round(p)) % 1 |
| 225 | + yUni = (2*arange(L)+1)/(2*L) # +p |
| 226 | + xyUni = vstack((xUni, yUni)) |
| 227 | + xySND = sqrt(2)*erfinv(2*xyUni-1) |
| 228 | + case 'LCD': |
| 229 | + xySND = get_data(self.url_SND_LCD(2, L)) |
| 230 | + xySND = matmul(self.rot(p), xySND) |
| 231 | + case 'SP-Julier04': |
| 232 | + # https://ieeexplore.ieee.org/abstract/document/1271397 |
| 233 | + Nx = 2 # dimension |
| 234 | + x0 = zeros([Nx, 1]) |
| 235 | + W0 = full([1, 1], p) # parameter, W0<1 |
| 236 | + x1 = sqrt(Nx/(1-W0) * identity(Nx)) |
| 237 | + W1 = full([1, Nx], (1-W0)/(2*Nx)) |
| 238 | + x2 = -x1 |
| 239 | + W2 = W1 |
| 240 | + xySND = hstack((x0, x1, x2)) |
| 241 | + weights = hstack((W0, W1, W2)) |
| 242 | + case 'SP-Menegaz11': |
| 243 | + # https://ieeexplore.ieee.org/abstract/document/6161480 |
| 244 | + n = 2 # dimension |
| 245 | + w0 = p # parameter, 0<w0<1 |
| 246 | + α = sqrt((1-w0)/n) |
| 247 | + CC2 = identity(n) - α**2 |
| 248 | + CC = cholesky(CC2) |
| 249 | + w1 = diag(w0 * α**2 * matmul(matmul(inv(CC), ones([n, n])), inv(CC.T))) |
| 250 | + x0 = full([n, 1], -α/sqrt(w0)) |
| 251 | + x1 = matmul(CC, inv(identity(n)*sqrt(w1))) |
| 252 | + # x1 = CC / sqrt(W1) |
| 253 | + xySND = hstack((x0, x1)) |
| 254 | + weights = hstack((p, w1)) |
| 255 | + case _: |
| 256 | + raise Exception("Wrong smethod") |
| 257 | + match tmethod: |
| 258 | + case 'Cholesky': |
| 259 | + xyG = matmul(cholesky(C), xySND) + μ |
| 260 | + case 'Eigendecomposition': |
| 261 | + xyG = matmul(C_R, sqrt(C_D) * xySND) + μ |
| 262 | + case _: |
| 263 | + raise Exception("Wrong smethod") |
| 264 | + # Sample weights to scatter sizes |
| 265 | + L2 = xySND.shape[1] # actual number of saamples |
| 266 | + if L2 == 0: |
| 267 | + sizes = 10 |
| 268 | + else: |
| 269 | + if weights is None: |
| 270 | + weights = 1/L2 # equally weighted |
| 271 | + else: |
| 272 | + weights = weights.flatten() |
| 273 | + sizes = sqrt(abs(weights) * L2) * det(2*pi*C)**(1/4) / sqrt(L2) * 70 |
| 274 | + # print(hstack((cov(xyG, bias=True, aweights=weights), C))) |
| 275 | + # Plot Ellipse |
| 276 | + elp = matmul(C_R, sqrt(C_D) * self.circ) + μ |
| 277 | + patched_fig['data'][0]['x'] = elp[0, :] |
| 278 | + patched_fig['data'][0]['y'] = elp[1, :] |
| 279 | + # Plot Samples |
| 280 | + patched_fig['data'][1]['x'] = xyG[0, :] |
| 281 | + patched_fig['data'][1]['y'] = xyG[1, :] |
| 282 | + patched_fig['data'][1]['marker']['size'] = sizes |
| 283 | + patched_fig['data'][1]['marker']['line']['width'] = sizes/20 |
| 284 | + return patched_fig |
| 285 | + |
| 286 | + |
| 287 | + @staticmethod |
| 288 | + def url_SND_LCD(D, L): |
| 289 | + return f'https://raw.githubusercontent.com/KIT-ISAS/deterministic-samples-csv/main/standard-normal/glcd/D{D}-N{L}.csv' |
| 290 | + |
| 291 | + @staticmethod |
| 292 | + def gauss1(x, μ, σ): |
| 293 | + return 1/sqrt(2*pi*σ) * exp(-1/2 * square((x-μ)/σ)) |
| 294 | + |
| 295 | + |
| 296 | + # Slider Transform, must be idencital to window.dccFunctions.trafo_L in assets/tooltip.js |
| 297 | + @staticmethod |
| 298 | + def trafo_L(L0): |
| 299 | + if L0 < log10(1.25): |
| 300 | + return 0 |
| 301 | + else: |
| 302 | + return round(10 ** L0) |
| 303 | + |
| 304 | + @staticmethod |
| 305 | + def rot(a): |
| 306 | + ar = deg2rad(a) |
| 307 | + return array([[cos(ar), -sin(ar)], [sin(ar), cos(ar)]]) |
| 308 | + |
| 309 | + |
| 310 | + |
| 311 | + |
0 commit comments