|
| 1 | +import numpy as np |
| 2 | +import xarray as xr |
| 3 | +from typing import Union, List, Dict, Optional, Tuple |
| 4 | +import PyCO2SYS as pyco2 |
| 5 | + |
| 6 | + |
| 7 | +from .utils import calculate_decimal_year, adjust_arctic_latitude, load_weight_file |
| 8 | + |
| 9 | + |
| 10 | +def canyonb( |
| 11 | + gtime: Union[np.ndarray, List], |
| 12 | + lat: np.ndarray, |
| 13 | + lon: np.ndarray, |
| 14 | + pres: np.ndarray, |
| 15 | + temp: np.ndarray, |
| 16 | + psal: np.ndarray, |
| 17 | + doxy: np.ndarray, |
| 18 | + param: Optional[List[str]] = None, |
| 19 | + epres: Optional[float] = 0.5, |
| 20 | + etemp: Optional[float] = 0.005, |
| 21 | + epsal: Optional[float] = 0.005, |
| 22 | + edoxy: Optional[Union[float, np.ndarray]] = None, |
| 23 | + weights_dir: str = None |
| 24 | +) -> Dict[str, xr.DataArray]: |
| 25 | + # TODO order='F' should be checked if needed for 2d-4d arrays as inputs |
| 26 | + # Using xarray: |
| 27 | + # For now 2d-matrix not handle because we are creating xarray matrix at the end with multiple dimension wihtout multiple dimensions names. Should address this issue. |
| 28 | + # This might be needed but not sure since all reshape, flatten and eisum are order='C'(or 'K') for indexing right now. |
| 29 | + """ |
| 30 | + CANYON-B neural network prediction for ocean parameters. |
| 31 | + |
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + gtime : array-like |
| 35 | + Date (UTC) as datetime objects or decimal years |
| 36 | + lat : array-like |
| 37 | + Latitude (-90 to 90) |
| 38 | + lon : array-like |
| 39 | + Longitude (-180 to 180 or 0 to 360) |
| 40 | + pres : array-like |
| 41 | + Pressure (dbar) |
| 42 | + temp : array-like |
| 43 | + In-situ temperature (°C) |
| 44 | + psal : array-like |
| 45 | + Salinity |
| 46 | + doxy : array-like |
| 47 | + Dissolved oxygen (µmol/kg) |
| 48 | + param : list of str, optional |
| 49 | + Parameters to calculate. Default calculates all. |
| 50 | + epres, etemp, epsal : float, optional |
| 51 | + Input errors |
| 52 | + edoxy : float or array-like, optional |
| 53 | + Oxygen input error (default: 1% of doxy) |
| 54 | + weights_dir : str |
| 55 | + Directory containing weight files |
| 56 | + |
| 57 | + Returns |
| 58 | + ------- |
| 59 | + Dict[str, xr.DataArray] |
| 60 | + Dictionary containing predictions and uncertainties |
| 61 | + """ |
| 62 | + # Convert inputs to numpy arrays |
| 63 | + arrays = [np.asarray(x) for x in (lat, lon, pres, temp, psal, doxy)] |
| 64 | + lat, lon, pres, temp, psal, doxy = arrays |
| 65 | + |
| 66 | + # Get array shape and number of elements |
| 67 | + shape = pres.shape |
| 68 | + nol = pres.size |
| 69 | + |
| 70 | + # Set default edoxy if not provided |
| 71 | + if edoxy is None: |
| 72 | + edoxy = 0.01 * doxy |
| 73 | + |
| 74 | + # Expand scalar error values |
| 75 | + errors = [epres, etemp, epsal, edoxy] |
| 76 | + errors = [np.full(nol, e) if np.isscalar(e) else np.asarray(e).flatten() |
| 77 | + for e in errors] |
| 78 | + epres, etemp, epsal, edoxy = errors |
| 79 | + |
| 80 | + # Define parameters and their properties |
| 81 | + paramnames = ['AT', 'CT', 'pH', 'pCO2', 'NO3', 'PO4', 'SiOH4'] |
| 82 | + inputsigma = np.array([6, 4, 0.005, np.nan, 0.02, 0.02, 0.02]) |
| 83 | + betaipCO2 = np.array([-3.114e-05, 1.087e-01, -7.899e+01]) |
| 84 | + |
| 85 | + # Adjust pH uncertainty |
| 86 | + inputsigma[2] = np.sqrt(0.005**2 + 0.01**2) |
| 87 | + |
| 88 | + # Set parameters to calculate |
| 89 | + if param is None: |
| 90 | + param = paramnames |
| 91 | + paramflag = np.array([p in param for p in paramnames]) |
| 92 | + |
| 93 | + # Prepare input data |
| 94 | + year = calculate_decimal_year(np.asarray(gtime).flatten()) |
| 95 | + adj_lat = adjust_arctic_latitude(lat.flatten(), lon.flatten()) |
| 96 | + |
| 97 | + # Create input matrix |
| 98 | + data = np.column_stack([ |
| 99 | + year, |
| 100 | + adj_lat / 90, |
| 101 | + np.abs(1 - np.mod(lon.flatten() - 110, 360) / 180), |
| 102 | + np.abs(1 - np.mod(lon.flatten() - 20, 360) / 180), |
| 103 | + temp.flatten(), |
| 104 | + psal.flatten(), |
| 105 | + doxy.flatten(), |
| 106 | + pres.flatten() / 2e4 + 1 / ((1 + np.exp(-pres.flatten() / 300))**3) |
| 107 | + ]) |
| 108 | + |
| 109 | + out = {} |
| 110 | + |
| 111 | + # Process each parameter |
| 112 | + for i, param_name in enumerate(paramnames): |
| 113 | + if not paramflag[i]: |
| 114 | + continue |
| 115 | + |
| 116 | + # Load weights |
| 117 | + inwgts = load_weight_file(weights_dir, param_name) |
| 118 | + noparsets = inwgts.shape[1] - 1 |
| 119 | + |
| 120 | + # Determine input normalization based on parameter type |
| 121 | + if i > 3: # nutrients |
| 122 | + ni = data[:, 1:].shape[1] |
| 123 | + ioffset = -1 |
| 124 | + mw = inwgts[:ni+1, -1] |
| 125 | + sw = inwgts[ni+1:2*ni+2, -1] |
| 126 | + data_N = (data[:, 1:] - mw[:ni]) / sw[:ni] |
| 127 | + else: # carbonate system |
| 128 | + ni = data.shape[1] |
| 129 | + ioffset = 0 |
| 130 | + mw = inwgts[:ni+1, -1] |
| 131 | + sw = inwgts[ni+1:2*ni+2, -1] |
| 132 | + data_N = (data - mw[:ni]) / sw[:ni] |
| 133 | + |
| 134 | + # Extract weights and prepare arrays |
| 135 | + wgts = inwgts[3, :noparsets] |
| 136 | + betaciw = inwgts[2*ni+2:, -1] |
| 137 | + betaciw = betaciw[~np.isnan(betaciw)] |
| 138 | + |
| 139 | + # Preallocate arrays |
| 140 | + cval = np.full((nol, noparsets), np.nan) |
| 141 | + cvalcy = np.full(noparsets, np.nan) |
| 142 | + inval = np.full((nol, ni, noparsets), np.nan) |
| 143 | + |
| 144 | + # Process each network in committee |
| 145 | + for l in range(noparsets): |
| 146 | + nlayerflag = 1 + bool(inwgts[1, l]) |
| 147 | + nl1 = int(inwgts[0, l]) |
| 148 | + nl2 = int(inwgts[1, l]) |
| 149 | + beta = inwgts[2, l] |
| 150 | + |
| 151 | + # Extract weights |
| 152 | + idx = 4 |
| 153 | + w1 = inwgts[idx:idx + nl1 * ni, l].reshape(nl1, ni, order='F') # Here, order='F'needed for sure to proper do the calculation as in matlab version ! |
| 154 | + idx += nl1*ni |
| 155 | + b1 = inwgts[idx:idx + nl1, l] |
| 156 | + idx += nl1 |
| 157 | + w2 = inwgts[idx:idx + nl2*nl1, l].reshape(nl2, nl1, order='F') |
| 158 | + idx += nl2*nl1 |
| 159 | + b2 = inwgts[idx:idx + nl2, l] |
| 160 | + |
| 161 | + if nlayerflag == 2: |
| 162 | + idx += nl2 |
| 163 | + w3 = inwgts[idx:idx + nl2, l].reshape(1, nl2, order='F') |
| 164 | + idx += nl2 |
| 165 | + b3 = inwgts[idx:idx + 1, l] |
| 166 | + |
| 167 | + # Forward pass |
| 168 | + a = np.dot(data_N, w1.T) + b1 |
| 169 | + if nlayerflag == 1: |
| 170 | + y = np.dot(np.tanh(a), w2.T) + b2 |
| 171 | + else: |
| 172 | + b = np.dot(np.tanh(a), w2.T) + b2 |
| 173 | + y = np.dot(np.tanh(b), w3.T) + b3 |
| 174 | + |
| 175 | + # Store results |
| 176 | + cval[:, l] = y.flatten() |
| 177 | + cvalcy[l] = 1/beta |
| 178 | + |
| 179 | + # Calculate input effects |
| 180 | + x1 = w1[None, :, :] * (1 - np.tanh(a)[:, :, None]**2) |
| 181 | + # jusque-là okay |
| 182 | + if nlayerflag == 1: |
| 183 | + #inx = np.einsum('ij,jkl->ikl', w2, x1) |
| 184 | + inx = np.einsum('ij,...jk->...ik', w2, x1)[:, 0, :] |
| 185 | + else: |
| 186 | + x2 = w2[None, :, :] * (1 - np.tanh(b)[:, :, None]**2) |
| 187 | + #inx = np.einsum('ij,jkl,kln->ikn', w3, x2, x1, order='F') |
| 188 | + inx = np.einsum('ij,...jk,...kl->...il', w3, x2, x1)[:, 0, :] |
| 189 | + inval[:, :, l] = inx |
| 190 | + |
| 191 | + # Denormalization |
| 192 | + cval = cval * sw[ni] + mw[ni] |
| 193 | + cvalcy = cvalcy * sw[ni]**2 |
| 194 | + |
| 195 | + # Calculate committee statistics |
| 196 | + V1 = np.sum(wgts) |
| 197 | + V2 = np.sum(wgts**2) |
| 198 | + pred = np.sum(wgts[None, :] * cval, axis=1) / V1 |
| 199 | + |
| 200 | + # Calculate uncertainties |
| 201 | + cvalcu = np.sum(wgts[None, :] * (cval - pred[:, None])**2, axis=1) / (V1 - V2/V1) |
| 202 | + cvalcib = np.sum(wgts * cvalcy) / V1 |
| 203 | + cvalciw = np.polyval(betaciw, np.sqrt(cvalcu))**2 |
| 204 | + |
| 205 | + # Calculate input effects |
| 206 | + inx = np.sum(wgts[None, None, :] * inval, axis=2) / V1 |
| 207 | + #inx = sw[ni] / sw[:ni] * inx |
| 208 | + inx = np.tile((sw[ni] / sw[0:ni].T), (nol, 1)) * inx |
| 209 | + |
| 210 | + # Pressure scaling |
| 211 | + ddp = 1/2e4 + 1/((1 + np.exp(-pres.flatten()/300))**4) * np.exp(-pres.flatten()/300)/100 # TODO order='F' ? |
| 212 | + inx[:, 7+ioffset] *= ddp |
| 213 | + |
| 214 | + # Calculate input variance |
| 215 | + error_matrix = np.column_stack([etemp, epsal, edoxy, epres]) |
| 216 | + cvalcin = np.sum(inx[:, 4+ioffset:8+ioffset]**2 * error_matrix**2, axis=1) |
| 217 | + |
| 218 | + # Calculate measurement uncertainty |
| 219 | + if i > 3: |
| 220 | + cvalcimeas = (inputsigma[i] * pred)**2 |
| 221 | + elif i == 3: |
| 222 | + cvalcimeas = np.polyval(betaipCO2, pred)**2 |
| 223 | + else: |
| 224 | + cvalcimeas = inputsigma[i]**2 |
| 225 | + |
| 226 | + # Calculate total uncertainty |
| 227 | + uncertainty = np.sqrt(cvalcimeas + cvalcib + cvalciw + cvalcu + cvalcin) |
| 228 | + |
| 229 | + # Create numpy arrays |
| 230 | + out[param_name] = np.reshape(pred, shape) |
| 231 | + out[f'{param_name}_ci'] = np.reshape(uncertainty, shape) |
| 232 | + out[f'{param_name}_cim'] = np.sqrt(cvalcimeas) |
| 233 | + out[f'{param_name}_cin'] = np.reshape(np.sqrt(cvalcib + cvalciw + cvalcu), shape) |
| 234 | + out[f'{param_name}_cii'] = np.reshape(np.sqrt(cvalcin), shape) |
| 235 | + |
| 236 | + # TODO: should be implemented here with xarray such as |
| 237 | + #coords = {'depth': pres.reshape(shape)} |
| 238 | + #out[param_name] = xr.DataArray( |
| 239 | + # pred.reshape(shape), |
| 240 | + # coords=coords, |
| 241 | + # dims=['depth'], |
| 242 | + # name=param_name |
| 243 | + #) |
| 244 | + |
| 245 | + # pCO2 |
| 246 | + if i == 3: |
| 247 | + # ipCO2 = 'DIC' / umol kg-1 -> pCO2 / uatm |
| 248 | + outcalc = pyco2.sys( |
| 249 | + par1=2300, |
| 250 | + par2=out[param_name], |
| 251 | + par1_type=1, |
| 252 | + par2_type=2, |
| 253 | + salinity=35., |
| 254 | + temperature=25., |
| 255 | + temperature_out=np.nan, |
| 256 | + pressure_out=0., |
| 257 | + pressure_atmosphere_out=np.nan, |
| 258 | + total_silicate=0., |
| 259 | + total_phosphate=0., |
| 260 | + opt_pH_scale=1., |
| 261 | + opt_k_carbonic=10., |
| 262 | + opt_k_bisulfate=1., |
| 263 | + grads_of=["pCO2"], |
| 264 | + grads_wrt=["par2"], |
| 265 | + ) |
| 266 | + |
| 267 | + out[f'{paramnames[i]}'] = outcalc['pCO2'] |
| 268 | + |
| 269 | + # epCO2 = dpCO2/dDIC * e'DIC' |
| 270 | + for unc in ['_ci', '_cim', '_cin', '_cii']: |
| 271 | + out[param_name + unc] = outcalc['d_pCO2__d_par2'] * out[param_name + unc] |
| 272 | + |
| 273 | + return out |
0 commit comments