Skip to content

Commit e02fda9

Browse files
authored
Merge pull request #30 from ClimateImpactLab/fusion
WIP: Merge Release-2.0 changes into master
2 parents 9a4fbcb + acab071 commit e02fda9

File tree

12 files changed

+48986
-33
lines changed

12 files changed

+48986
-33
lines changed

openest/curves/smart_linextrap.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ def __call__(self, ds):
6060
values = self.curve(ds)
6161
indeps = np.stack((ds[indepvar] for indepvar in self.indepvars), axis=-1)
6262

63-
return linextrap.replace_oob(values, indeps, self.curve.get_univariate(), self.bounds, self.margins, self.scaling)
63+
return linextrap.replace_oob(values, indeps, self.curve.univariate, self.bounds, self.margins, self.scaling)
6464

65-
def get_univariate(self):
65+
@property
66+
def univariate(self):
6667
"""Return a UnivariateCurve version of this curve."""
67-
return linextrap.LinearExtrapolationCurve(self.curve.get_univariate(), self.bounds, self.margins,
68+
return linextrap.LinearExtrapolationCurve(self.curve.univariate, self.bounds, self.margins,
6869
self.scaling, lambda xs: xs[:, 0])
6970

7071

openest/curves/ushape_numeric.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import numpy as np
2+
import xarray as xr
3+
from openest.generate import fast_dataset
24
from openest.curves.basic import UnivariateCurve
35

46

@@ -33,11 +35,10 @@ def __init__(self, curve, midtemp, gettas, ordered='maintain', fillxxs=None, fil
3335

3436
def __call__(self, xs):
3537
values = self.curve(xs)
36-
tas = np.array(self.gettas(xs))
38+
tas = tas_original = np.array(self.gettas(xs))
3739

3840
# Add in the grid for completeness
3941
if len(self.fillxxs) > 0:
40-
tas_saved = tas
4142
tas = np.concatenate((tas, self.fillxxs))
4243
values = np.concatenate((values, self.fillyys))
4344

@@ -65,10 +66,10 @@ def __call__(self, xs):
6566
tasorder = np.empty(len(increasing))
6667
tasorder[order] = increasing
6768
# Return just the given values
68-
return tasorder[:len(xs)]
69+
return tasorder[:len(tas_original)]
6970
else:
7071
if len(self.fillxxs) > 0:
71-
tokeeps = order < len(tas_saved)
72+
tokeeps = order < len(tas_original)
7273
lowkeeps = tokeeps[orderedtas < self.midtemp]
7374
highkeeps = tokeeps[orderedtas >= self.midtemp]
7475

@@ -157,7 +158,10 @@ def __call__(self, xs):
157158
highindicesofordered = np.maximum.accumulate(highindicesofordered)
158159

159160
# Construct the results
160-
if len(xs.shape) == 2:
161+
if isinstance(xs, xr.Dataset):
162+
newxs = fast_dataset.reorder_coord(xs, 'time', np.concatenate((order[lowindicesofordered], order[highindicesofordered])))
163+
increasingresults = self.tmarginal_curve(newxs) # ordered low..., high...
164+
elif len(xs.shape) == 2:
161165
increasingresults = np.concatenate((self.tmarginal_curve(xs[order[lowindicesofordered], :]),
162166
self.tmarginal_curve(xs[order[highindicesofordered], :]))) # ordered low..., high...
163167
else:

openest/generate/fast_dataset.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def isel(self, **kwargs):
146146

147147
return FastDataset(newdata_vars, newcoords, self.attrs)
148148

149+
def load(self):
150+
pass # always ready
151+
149152
def __getitem__(self, name):
150153
if name not in self._variables:
151154
if name == 'time.year' and 'time' in self._variables:
@@ -472,6 +475,40 @@ def assert_index_equal(one, two):
472475
assert np.array_equal(two, one), "Not equal: %s <> %s" % (str(two), str(one))
473476
return one
474477

478+
def reorder_coord(ds, dim, indices):
479+
assert dim in ['time', 'region'] # assume a time x region ds
480+
assert dim in ds.coords
481+
482+
newvars = {}
483+
for var in ds.variables:
484+
try:
485+
dims = ds[var].coords
486+
if len(dims) == 1:
487+
if dim not in dims:
488+
newvars[var] = ds[var]
489+
else:
490+
newvars[var] = ([dim], ds[var].values[indices])
491+
elif len(dims) == 2 and dim in dims:
492+
if dims.index(dim) == 0:
493+
newvars[var] = (dims, ds[var].values[:, indices])
494+
else:
495+
newvars[var] = (dims, ds[var].values[indices, :])
496+
else:
497+
tindex = list(dims).index(dim)
498+
allindices = [slice(None)] * len(dims)
499+
allindices[tindex] = indices
500+
newvars[var] = (dims, ds[var].values[tuple(allindices)])
501+
except Exception:
502+
print("Failed to reorder %s for %s" % (var, ds))
503+
raise
504+
505+
coords = ds.coords
506+
coords[dim] = ds[dim][indices]
507+
newds = ds.__class__(newvars, coords=coords) # work for Dataset or FastDataset
508+
newds.load()
509+
510+
return newds
511+
475512
FastDataArray.__array_priority__ = 80
476513
xr.core.ops.inject_binary_ops(FastDataArray)
477514

openest/generate/functions.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -868,11 +868,15 @@ class AuxiliaryResult(calculation.Calculation):
868868
"""
869869
Produce an additional output, but then pass the main result on.
870870
"""
871-
def __init__(self, subcalc_main, subcalc_aux, auxname):
872-
super(AuxiliaryResult, self).__init__([subcalc_main.unitses[0], subcalc_aux.unitses[0]] + subcalc_main.unitses[1:])
871+
def __init__(self, subcalc_main, subcalc_aux, auxname, keeplastonly=True):
872+
if keeplastonly:
873+
super(AuxiliaryResult, self).__init__([subcalc_main.unitses[0], subcalc_aux.unitses[0]] + subcalc_main.unitses[1:])
874+
else:
875+
super(AuxiliaryResult, self).__init__([subcalc_main.unitses[0]] + subcalc_aux.unitses + subcalc_main.unitses[1:])
873876
self.subcalc_main = subcalc_main
874877
self.subcalc_aux = subcalc_aux
875878
self.auxname = auxname
879+
self.keeplastonly = keeplastonly
876880

877881
def format(self, lang, *args, **kwargs):
878882
beforeauxlen = len(formatting.format_labels)
@@ -884,7 +888,7 @@ def format(self, lang, *args, **kwargs):
884888
def apply(self, region, *args, **kwargs):
885889
subapp_main = self.subcalc_main.apply(region, *args, **kwargs)
886890
subapp_aux = self.subcalc_aux.apply(region, *args, **kwargs)
887-
return AuxiliaryResultApplication(region, subapp_main, subapp_aux)
891+
return AuxiliaryResultApplication(region, subapp_main, subapp_aux, self.keeplastonly)
888892

889893
def partial_derivative(self, covariate, covarunit):
890894
"""
@@ -899,7 +903,10 @@ def column_info(self):
899903
infos_aux = self.subcalc_aux.column_info()
900904
infos_aux[0]['name'] = self.auxname
901905

902-
return [infos_main[0], infos_aux[0]] + infos_main[1:]
906+
if self.keeplastonly:
907+
return [infos_main[0], infos_aux[0]] + infos_main[1:]
908+
else:
909+
return [infos_main[0]] + infos_aux + infos_main[1:]
903910

904911
@staticmethod
905912
def describe():
@@ -927,16 +934,20 @@ class AuxiliaryResultApplication(calculation.Application):
927934
"""
928935
Perform both main and auxiliary calculation, and order as main[0], aux, main[1:]
929936
"""
930-
def __init__(self, region, subapp_main, subapp_aux):
937+
def __init__(self, region, subapp_main, subapp_aux, keeplastonly):
931938
super(AuxiliaryResultApplication, self).__init__(region)
932939
self.subapp_main = subapp_main
933940
self.subapp_aux = subapp_aux
941+
self.keeplastonly = keeplastonly
934942

935943
def push(self, ds):
936944
for yearresult in self.subapp_main.push(ds):
937945
for yearresult_aux in self.subapp_aux.push(ds):
938-
next # Just take the last one
939-
yield list(yearresult[0:2]) + [yearresult_aux[1]] + list(yearresult[2:])
946+
pass # Just take the last one
947+
if self.keeplastonly:
948+
yield list(yearresult[0:2]) + [yearresult_aux[1]] + list(yearresult[2:])
949+
else:
950+
yield list(yearresult[0:2]) + list(yearresult_aux[1:]) + list(yearresult[2:])
940951

941952
def done(self):
942953
self.subapp_aux.done()

openest/generate/smart_curve.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def __init__(self):
2222
def __call__(self, ds):
2323
raise NotImplementedError("call not implemented")
2424

25-
def get_univariate(self):
26-
raise NotImplementedError("get_univariate not implemented")
25+
@property
26+
def univariate(self):
27+
raise NotImplementedError("univariate not implemented")
2728

2829
def format(self, lang):
2930
raise NotImplementedError()
@@ -120,7 +121,8 @@ def __call__(self, ds):
120121

121122
return result
122123

123-
def get_univariate(self):
124+
@property
125+
def univariate(self):
124126
return curve.ZeroInterceptPolynomialCurve([-np.inf, np.inf], self.coeffs)
125127

126128
def format(self, lang):
@@ -171,6 +173,7 @@ class SumByTimePolynomialCurve(SmartCurve):
171173
def __init__(self, coeffmat, variables, allow_raising=False, descriptions=None):
172174
super(SumByTimePolynomialCurve, self).__init__()
173175
self.coeffmat = coeffmat # K x T
176+
assert len(self.coeffmat.shape) == 2
174177
self.variables = variables
175178
self.allow_raising = allow_raising
176179
if descriptions is None:
@@ -259,7 +262,8 @@ def __call__(self, ds):
259262

260263
raise ex
261264

262-
def get_univariate(self):
265+
@property
266+
def univariate(self):
263267
return curve.CubicSplineCurve(self.knots, self.coeffs)
264268

265269
class TransformCoefficientsCurve(SmartCurve):
@@ -351,22 +355,30 @@ def __init__(self, curve, offset):
351355
self.offset = offset
352356

353357
def __call__(self, ds):
354-
return self.curve(ds) - self.offset
358+
return self.curve(ds) + self.offset
355359

356-
def get_univariate(self):
357-
return curve.ShiftedCurve(self.curve.get_univariate(), self.offset)
360+
@property
361+
def univariate(self):
362+
return curve.ShiftedCurve(self.curve.univariate, self.offset)
358363

359364
def format(self, lang):
360-
return formatting.build_recursive({'latex': r"(%s - " + str(self.offset) + ")",
361-
'julia': r"(%s - " + str(self.offset) + ")"},
365+
return formatting.build_recursive({'latex': r"(%s + " + str(self.offset) + ")",
366+
'julia': r"(%s + " + str(self.offset) + ")"},
362367
lang, self.curve)
363368

364369
class ClippedCurve(curve.ClippedCurve, SmartCurve):
365-
def get_univariate(self):
366-
return curve.ClippedCurve(self.curve.get_univariate(), self.cliplow)
370+
@property
371+
def univariate(self):
372+
return curve.ClippedCurve(self.curve.univariate, self.cliplow)
373+
374+
class OtherClippedCurve(curve.OtherClippedCurve, SmartCurve):
375+
@property
376+
def univariate(self):
377+
return curve.OtherClippedCurve(self.clipping_curve.univariate, self.curve.univariate, self.clipy)
367378

368379
class MinimumCurve(curve.MinimumCurve, SmartCurve):
369-
def get_univariate(self):
370-
return curve.MinimumCurve(self.curve1.get_univariate(), self.curve2.get_univariate())
380+
@property
381+
def univariate(self):
382+
return curve.MinimumCurve(self.curve1.univariate, self.curve2.univariate)
371383

372384

openest/generate/yearly.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,26 @@ def describe():
5757
description="Apply the results of a previous calculation to a curve.")
5858

5959
class YearlyCoefficients(Calculation):
60-
def __init__(self, units, curvegen, curve_description, getter=lambda curve: curve.yy, weather_change=lambda region, x: x):
60+
def __init__(self, units, curvegen, curve_description, getter=lambda curve: curve.yy, weather_change=lambda region, x: x, label='response'):
6161
super(YearlyCoefficients, self).__init__([units])
6262
assert isinstance(curvegen, CurveGenerator)
6363

6464
self.curvegen = curvegen
6565
self.curve_description = curve_description
6666
self.getter = getter
6767
self.weather_change = weather_change
68+
self.label = label
6869

6970
def apply(self, region, *args):
7071
def generate(region, year, temps, **kw):
7172
curve = self.curvegen.get_curve(region, year, *args, weather=temps) # Passing in original (not weather-changed) data
7273

7374
coeffs = self.getter(region, year, temps, curve)
74-
if len(temps) == len(coeffs):
75-
result = np.sum(self.weather_change(region, temps).dot(coeffs))
75+
temps2 = self.weather_change(region, temps)
76+
if np.isscalar(temps2) and np.isscalar(coeffs):
77+
result = temps2 * coeffs
78+
elif len(temps) == len(coeffs):
79+
result = np.sum(np.array(temps).dot(coeffs))
7680
else:
7781
raise RuntimeError("Unknown format for temps: " + str(temps.shape) + " <> len " + str(coeffs))
7882

@@ -87,7 +91,7 @@ def generate(region, year, temps, **kw):
8791

8892
def column_info(self):
8993
description = "The combined result of yearly values, with coefficients from %s." % (str(self.curve_description))
90-
return [dict(name='response', title='Direct marginal response', description=description)]
94+
return [dict(name=self.label, title='Direct marginal response', description=description)]
9195

9296
@staticmethod
9397
def describe():

openest/models/curve.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def make_linear_spline_curve(xx, yy, limits):
3636

3737
class FlatCurve(CurveCurve):
3838
def __init__(self, yy):
39+
self.yy = yy
3940
super(FlatCurve, self).__init__([-np.inf, np.inf], lambda x: yy)
4041

4142
class LinearCurve(CurveCurve):
@@ -149,7 +150,7 @@ def __call__(self, xs):
149150
clipping = self.clipping_curve(xs)
150151
ys = [y if y is not None else 0 for y in ys]
151152
clipping = [y if not np.isnan(y) else 0 for y in clipping]
152-
return ys * (clipping > self.clipy)
153+
return ys * (np.array(clipping) > self.clipy)
153154

154155
class MinimumCurve(UnivariateCurve):
155156
def __init__(self, curve1, curve2):
@@ -183,6 +184,7 @@ def __call__(self, xs):
183184
class PiecewiseCurve(UnivariateCurve):
184185
def __init__(self, curves, knots, xtrans=lambda x: x):
185186
super(PiecewiseCurve, self).__init__(knots)
187+
assert len(curves) == len(knots) - 1
186188
self.curves = curves
187189
self.knots = knots
188190
self.xtrans = xtrans # for example, to select first column

0 commit comments

Comments
 (0)