Skip to content

Commit 788f8f4

Browse files
authored
Merge pull request #5 from RaphaelBajon/RaphaelBajon-shape-1
Fix shape issue for pCO2 uncertainty and upload tests
2 parents 61a85d4 + 7f5774e commit 788f8f4

File tree

5 files changed

+390
-5
lines changed

5 files changed

+390
-5
lines changed
Lines changed: 273 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,273 @@
1+
import numpy as np
2+
import xarray as xr
3+
from typing import Union, List, Dict, Optional, Tuple
4+
import PyCO2SYS as pyco2
5+
6+
7+
from .utils import calculate_decimal_year, adjust_arctic_latitude, load_weight_file
8+
9+
10+
def canyonb(
    gtime: Union[np.ndarray, List],
    lat: np.ndarray,
    lon: np.ndarray,
    pres: np.ndarray,
    temp: np.ndarray,
    psal: np.ndarray,
    doxy: np.ndarray,
    param: Optional[List[str]] = None,
    epres: Optional[float] = 0.5,
    etemp: Optional[float] = 0.005,
    epsal: Optional[float] = 0.005,
    edoxy: Optional[Union[float, np.ndarray]] = None,
    weights_dir: str = None
) -> Dict[str, xr.DataArray]:
    """
    CANYON-B neural network prediction for ocean parameters.

    Runs the CANYON-B Bayesian neural-network committee for each requested
    parameter and propagates committee, network-weight, input and measurement
    uncertainties. For pCO2 the network actually predicts a DIC-like quantity
    which is converted to pCO2 with PyCO2SYS, and the uncertainties are scaled
    by the dpCO2/dDIC gradient.

    Parameters
    ----------
    gtime : array-like
        Date (UTC) as datetime objects or decimal years
    lat : array-like
        Latitude (-90 to 90)
    lon : array-like
        Longitude (-180 to 180 or 0 to 360)
    pres : array-like
        Pressure (dbar)
    temp : array-like
        In-situ temperature (°C)
    psal : array-like
        Salinity
    doxy : array-like
        Dissolved oxygen (µmol/kg)
    param : list of str, optional
        Parameters to calculate (subset of 'AT', 'CT', 'pH', 'pCO2', 'NO3',
        'PO4', 'SiOH4'). Default calculates all.
    epres, etemp, epsal : float, optional
        Input errors
    edoxy : float or array-like, optional
        Oxygen input error (default: 1% of doxy)
    weights_dir : str
        Directory containing weight files

    Returns
    -------
    Dict[str, xr.DataArray]
        Dictionary containing, per parameter name ``p``: the prediction
        (``p``), the total uncertainty (``p_ci``), the measurement
        uncertainty (``p_cim``, NOT reshaped — flat or scalar), the
        network uncertainty (``p_cin``) and the input uncertainty
        (``p_cii``).
    """
    # TODO order='F' should be checked if needed for 2d-4d arrays as inputs
    # Using xarray:
    # For now 2d matrices are not handled because the xarray outputs would need
    # multiple named dimensions. Should address this issue.
    # order='F' might be needed, but all reshape, flatten and einsum calls are
    # currently order='C' (or 'K') for indexing.

    # Convert inputs to numpy arrays
    arrays = [np.asarray(x) for x in (lat, lon, pres, temp, psal, doxy)]
    lat, lon, pres, temp, psal, doxy = arrays

    # Get array shape and number of elements
    shape = pres.shape
    nol = pres.size

    # Set default edoxy if not provided
    if edoxy is None:
        edoxy = 0.01 * doxy

    # Expand scalar error values to one value per sample
    errors = [epres, etemp, epsal, edoxy]
    errors = [np.full(nol, e) if np.isscalar(e) else np.asarray(e).flatten()
              for e in errors]
    epres, etemp, epsal, edoxy = errors

    # Define parameters and their properties
    paramnames = ['AT', 'CT', 'pH', 'pCO2', 'NO3', 'PO4', 'SiOH4']
    inputsigma = np.array([6, 4, 0.005, np.nan, 0.02, 0.02, 0.02])
    betaipCO2 = np.array([-3.114e-05, 1.087e-01, -7.899e+01])

    # Adjust pH uncertainty (spectrophotometric + conversion contribution)
    inputsigma[2] = np.sqrt(0.005**2 + 0.01**2)

    # Set parameters to calculate
    if param is None:
        param = paramnames
    paramflag = np.array([p in param for p in paramnames])

    # Prepare input data
    year = calculate_decimal_year(np.asarray(gtime).flatten())
    adj_lat = adjust_arctic_latitude(lat.flatten(), lon.flatten())

    # Create input matrix: time, scaled position, T, S, O2 and a pressure
    # transform that compresses the deep ocean
    data = np.column_stack([
        year,
        adj_lat / 90,
        np.abs(1 - np.mod(lon.flatten() - 110, 360) / 180),
        np.abs(1 - np.mod(lon.flatten() - 20, 360) / 180),
        temp.flatten(),
        psal.flatten(),
        doxy.flatten(),
        pres.flatten() / 2e4 + 1 / ((1 + np.exp(-pres.flatten() / 300))**3)
    ])

    out = {}

    # Process each parameter
    for i, param_name in enumerate(paramnames):
        if not paramflag[i]:
            continue

        # Load weights (last column holds normalization / polynomial data)
        inwgts = load_weight_file(weights_dir, param_name)
        noparsets = inwgts.shape[1] - 1

        # Determine input normalization based on parameter type
        if i > 3:  # nutrients: networks do not use the decimal year input
            ni = data[:, 1:].shape[1]
            ioffset = -1
            mw = inwgts[:ni+1, -1]
            sw = inwgts[ni+1:2*ni+2, -1]
            data_N = (data[:, 1:] - mw[:ni]) / sw[:ni]
        else:  # carbonate system: full input matrix
            ni = data.shape[1]
            ioffset = 0
            mw = inwgts[:ni+1, -1]
            sw = inwgts[ni+1:2*ni+2, -1]
            data_N = (data - mw[:ni]) / sw[:ni]

        # Extract committee weights and network-uncertainty polynomial
        wgts = inwgts[3, :noparsets]
        betaciw = inwgts[2*ni+2:, -1]
        betaciw = betaciw[~np.isnan(betaciw)]

        # Preallocate arrays
        cval = np.full((nol, noparsets), np.nan)
        cvalcy = np.full(noparsets, np.nan)
        inval = np.full((nol, ni, noparsets), np.nan)

        # Process each network in committee
        for l in range(noparsets):
            nlayerflag = 1 + bool(inwgts[1, l])
            nl1 = int(inwgts[0, l])
            nl2 = int(inwgts[1, l])
            beta = inwgts[2, l]

            # Extract weights
            idx = 4
            w1 = inwgts[idx:idx + nl1 * ni, l].reshape(nl1, ni, order='F')  # order='F' needed to reproduce the MATLAB layout
            idx += nl1*ni
            b1 = inwgts[idx:idx + nl1, l]
            idx += nl1
            w2 = inwgts[idx:idx + nl2*nl1, l].reshape(nl2, nl1, order='F')
            idx += nl2*nl1
            b2 = inwgts[idx:idx + nl2, l]

            if nlayerflag == 2:
                idx += nl2
                w3 = inwgts[idx:idx + nl2, l].reshape(1, nl2, order='F')
                idx += nl2
                b3 = inwgts[idx:idx + 1, l]

            # Forward pass
            a = np.dot(data_N, w1.T) + b1
            if nlayerflag == 1:
                y = np.dot(np.tanh(a), w2.T) + b2
            else:
                b = np.dot(np.tanh(a), w2.T) + b2
                y = np.dot(np.tanh(b), w3.T) + b3

            # Store results
            cval[:, l] = y.flatten()
            cvalcy[l] = 1/beta

            # Calculate input effects (Jacobian of the network w.r.t. inputs)
            x1 = w1[None, :, :] * (1 - np.tanh(a)[:, :, None]**2)
            # verified okay up to here
            if nlayerflag == 1:
                #inx = np.einsum('ij,jkl->ikl', w2, x1)
                inx = np.einsum('ij,...jk->...ik', w2, x1)[:, 0, :]
            else:
                x2 = w2[None, :, :] * (1 - np.tanh(b)[:, :, None]**2)
                #inx = np.einsum('ij,jkl,kln->ikn', w3, x2, x1, order='F')
                inx = np.einsum('ij,...jk,...kl->...il', w3, x2, x1)[:, 0, :]
            inval[:, :, l] = inx

        # Denormalization
        cval = cval * sw[ni] + mw[ni]
        cvalcy = cvalcy * sw[ni]**2

        # Calculate committee statistics (weighted committee mean)
        V1 = np.sum(wgts)
        V2 = np.sum(wgts**2)
        pred = np.sum(wgts[None, :] * cval, axis=1) / V1

        # Calculate uncertainties: committee spread, network-weight term,
        # and the polynomial-modelled network uncertainty
        cvalcu = np.sum(wgts[None, :] * (cval - pred[:, None])**2, axis=1) / (V1 - V2/V1)
        cvalcib = np.sum(wgts * cvalcy) / V1
        cvalciw = np.polyval(betaciw, np.sqrt(cvalcu))**2

        # Calculate input effects (committee-averaged, denormalized Jacobian)
        inx = np.sum(wgts[None, None, :] * inval, axis=2) / V1
        #inx = sw[ni] / sw[:ni] * inx
        inx = np.tile((sw[ni] / sw[0:ni].T), (nol, 1)) * inx

        # Pressure scaling: derivative of the pressure transform above
        ddp = 1/2e4 + 1/((1 + np.exp(-pres.flatten()/300))**4) * np.exp(-pres.flatten()/300)/100  # TODO order='F' ?
        inx[:, 7+ioffset] *= ddp

        # Calculate input variance (propagate T, S, O2, P errors)
        error_matrix = np.column_stack([etemp, epsal, edoxy, epres])
        cvalcin = np.sum(inx[:, 4+ioffset:8+ioffset]**2 * error_matrix**2, axis=1)

        # Calculate measurement uncertainty (relative for nutrients,
        # polynomial in pred for pCO2, absolute otherwise)
        if i > 3:
            cvalcimeas = (inputsigma[i] * pred)**2
        elif i == 3:
            cvalcimeas = np.polyval(betaipCO2, pred)**2
        else:
            cvalcimeas = inputsigma[i]**2

        # Calculate total uncertainty
        uncertainty = np.sqrt(cvalcimeas + cvalcib + cvalciw + cvalcu + cvalcin)

        # Create numpy arrays. NOTE: '_cim' is intentionally kept flat
        # (it can be a scalar for AT/CT/pH), unlike the reshaped outputs.
        out[param_name] = np.reshape(pred, shape)
        out[f'{param_name}_ci'] = np.reshape(uncertainty, shape)
        out[f'{param_name}_cim'] = np.sqrt(cvalcimeas)
        out[f'{param_name}_cin'] = np.reshape(np.sqrt(cvalcib + cvalciw + cvalcu), shape)
        out[f'{param_name}_cii'] = np.reshape(np.sqrt(cvalcin), shape)

        # TODO: should be implemented here with xarray such as
        #coords = {'depth': pres.reshape(shape)}
        #out[param_name] = xr.DataArray(
        #    pred.reshape(shape),
        #    coords=coords,
        #    dims=['depth'],
        #    name=param_name
        #)

        # pCO2
        if i == 3:
            # ipCO2 = 'DIC' / umol kg-1 -> pCO2 / uatm
            outcalc = pyco2.sys(
                par1=2300,
                par2=out[param_name],
                par1_type=1,
                par2_type=2,
                salinity=35.,
                temperature=25.,
                temperature_out=np.nan,
                pressure_out=0.,
                pressure_atmosphere_out=np.nan,
                total_silicate=0.,
                total_phosphate=0.,
                opt_pH_scale=1.,
                opt_k_carbonic=10.,
                opt_k_bisulfate=1.,
                grads_of=["pCO2"],
                grads_wrt=["par2"],
            )

            out[f'{paramnames[i]}'] = outcalc['pCO2']

            # epCO2 = dpCO2/dDIC * e'DIC'
            for unc in ['_ci', '_cin', '_cii']:
                out[param_name + unc] = outcalc['d_pCO2__d_par2'] * out[param_name + unc]

            # BUG FIX: '_cim' was stored un-reshaped (flat, shape (nol,)),
            # while the gradient 'd_pCO2__d_par2' has the input shape; it
            # must be reshaped before multiplying or the product breaks /
            # mis-broadcasts for multi-dimensional inputs.
            out[param_name + '_cim'] = outcalc['d_pCO2__d_par2'] * np.reshape(out[param_name + '_cim'], shape)

    return out

canyonbpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .core import canyonb
22
from .utils import calculate_decimal_year, adjust_arctic_latitude
33

4-
__version__ = "0.1.1"
4+
__version__ = "0.2.1"
55
__all__ = [
66
"canyonb",
77
"calculate_decimal_year",

canyonbpy/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ def canyonb(
267267
out[f'{paramnames[i]}'] = outcalc['pCO2']
268268

269269
# epCO2 = dpCO2/dDIC * e'DIC'
270-
for unc in ['_ci', '_cim', '_cin', '_cii']:
271-
out[param_name + unc] = outcalc['d_pCO2__d_par2'] * out[param_name + unc]
270+
for unc in ['_ci', '_cin', '_cii']:
271+
out[param_name + unc] = outcalc['d_pCO2__d_par2'] * out[param_name + unc]
272+
273+
out[param_name + '_cim'] = outcalc['d_pCO2__d_par2'] * np.reshape(out[param_name + '_cim'], shape)
272274

273275
return out

0 commit comments

Comments
 (0)