Skip to content

Commit d9bdae3

Browse files
authored
Merge pull request #83 from ClimateImpactLab/dscim-v0.4.0_fixes
Fix chunking issues in sum_AMEL and reduce_damages
2 parents 152ae4f + 43b7843 commit d9bdae3

File tree

4 files changed

+184
-15
lines changed

4 files changed

+184
-15
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [0.4.0] - Unreleased
88
### Added
9+
- Functions to concatenate input damages across batches. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@davidrzhdu](https://github.com/davidrzhdu))
910
- New unit tests for [dscim/utils/input_damages.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/input_damages.py). ([PR #68](https://github.com/ClimateImpactLab/dscim/pull/68), [@davidrzhdu](https://github.com/davidrzhdu))
1011
- New unit tests for [dscim/utils/rff.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/utils/rff.py). ([PR #73](https://github.com/ClimateImpactLab/dscim/pull/73), [@JMGilbert](https://github.com/JMGilbert))
1112
- New unit tests for [dscim/dscim/preprocessing.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/preprocessing.py). ([PR #67](https://github.com/ClimateImpactLab/dscim/pull/67), [@JMGilbert](https://github.com/JMGilbert))
@@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2324
- Remove old/unnecessary files. ([PR #57](https://github.com/ClimateImpactLab/dscim/pull/57), [@JMGilbert](https://github.com/JMGilbert))
2425
- Remove unused “save_path” and “ec_cls” from `read_energy_files_parallel()`. ([PR #56](https://github.com/ClimateImpactLab/dscim/pull/56), [@davidrzhdu](https://github.com/davidrzhdu))
2526
### Fixed
27+
- Make all input damages output files with correct chunksizes. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@JMGilbert](https://github.com/JMGilbert))
2628
- Add `.load()` to every loading of population data from EconVars. ([PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu))
2729
- Make `compute_ag_damages` function correctly save outputs in float32. ([PR #72](https://github.com/ClimateImpactLab/dscim/pull/72) and [PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu))
2830
- Make rff damage functions read in and save out in the proper filepath structure. ([PR #79](https://github.com/ClimateImpactLab/dscim/pull/79), [@JMGilbert](https://github.com/JMGilbert))

src/dscim/preprocessing/input_damages.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import os
6-
import glob
76
import re
87
import logging
98
import warnings
@@ -95,6 +94,50 @@ def _parse_projection_filesys(input_path, query="exists==True"):
9594
return df.query(query)
9695

9796

97+
def concatenate_damage_output(damage_dir, basename, save_path, n_batches=15):
    """Concatenate labor/energy damage output across batches into one zarr store.

    Opens the per-batch zarr stores ``{damage_dir}/{basename}_batch{i}.zarr``
    for ``i`` in ``0..n_batches-1``, combines them along the ``batch``
    dimension, rechunks, and writes a single consolidated zarr store.

    Parameters
    ----------
    damage_dir : str
        Directory containing the separate labor/energy damage output files,
        one per batch.
    basename : str
        Prefix of the damage output filenames (e.g. ``{basename}_batch0.zarr``).
    save_path : str
        Path at which the concatenated dataset is saved in zarr format.
    n_batches : int, optional
        Number of batch files to concatenate (default 15, i.e.
        ``batch0`` .. ``batch14``), matching the projection system's
        standard batch count.
    """
    paths = [
        f"{damage_dir}/{basename}_batch{i}.zarr" for i in range(n_batches)
    ]
    data = xr.open_mfdataset(paths=paths, engine="zarr")

    # Drop the per-file chunk encoding carried over from the inputs so the
    # rechunking below (not the stale encoding) determines the on-disk chunks.
    for v in data:
        del data[v].encoding["chunks"]

    chunkies = {
        "batch": n_batches,  # keep all batches in a single chunk
        "rcp": 1,
        "gcm": 1,
        "model": 1,
        "ssp": 1,
        "region": -1,  # -1: one chunk spanning the full region dimension
        "year": 10,
    }
    data = data.chunk(chunkies)

    # zarr cannot serialize object-dtype arrays; cast any object-dtype
    # coordinates and variables to unicode strings before writing.
    for v in list(data.coords.keys()):
        if data.coords[v].dtype == object:
            data.coords[v] = data.coords[v].astype("unicode")
    data.coords["batch"] = data.coords["batch"].astype("unicode")
    for v in list(data.variables.keys()):
        if data[v].dtype == object:
            data[v] = data[v].astype("unicode")

    data.to_zarr(save_path, mode="w")
139+
140+
98141
def calculate_labor_impacts(input_path, file_prefix, variable, val_type):
99142
"""Calculate impacts for labor results.
100143
@@ -371,7 +414,7 @@ def process_batch(g):
371414
batches = [ds for ds in batches if ds is not None]
372415
chunkies = {
373416
"rcp": 1,
374-
"region": 24378,
417+
"region": -1,
375418
"gcm": 1,
376419
"year": 10,
377420
"model": 1,
@@ -738,12 +781,21 @@ def prep(
738781
).expand_dims({"gcm": [gcm]})
739782

740783
damages = damages.chunk(
741-
{"batch": 15, "ssp": 1, "model": 1, "rcp": 1, "gcm": 1, "year": 10}
784+
{
785+
"batch": 15,
786+
"ssp": 1,
787+
"model": 1,
788+
"rcp": 1,
789+
"gcm": 1,
790+
"year": 10,
791+
"region": -1,
792+
}
742793
)
743794
damages.coords.update({"batch": [f"batch{i}" for i in damages.batch.values]})
744795

745796
# convert to EPA VSL
746797
damages = damages * 0.90681089
798+
damages = damages.astype(np.float32)
747799

748800
for v in list(damages.coords.keys()):
749801
if damages.coords[v].dtype == object:
@@ -790,6 +842,15 @@ def coastal_inputs(
790842
)
791843
else:
792844
d = d.sel(adapt_type=adapt_type, vsl_valuation=vsl_valuation, drop=True)
845+
chunkies = {
846+
"batch": 15,
847+
"ssp": 1,
848+
"model": 1,
849+
"slr": 1,
850+
"year": 10,
851+
"region": -1,
852+
}
853+
d = d.chunk(chunkies)
793854
d.to_zarr(
794855
f"{path}/coastal_damages_{version}-{adapt_type}-{vsl_valuation}.zarr",
795856
consolidated=True,

src/dscim/preprocessing/preprocessing.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,24 @@ def reduce_damages(
102102
xr.open_zarr(damages).chunks["batch"][0] == 15
103103
), "'batch' dim on damages does not have chunksize of 15. Please rechunk."
104104

105+
if "coastal" not in sector:
106+
chunkies = {
107+
"rcp": 1,
108+
"region": -1,
109+
"gcm": 1,
110+
"year": 10,
111+
"model": 1,
112+
"ssp": 1,
113+
}
114+
else:
115+
chunkies = {
116+
"region": -1,
117+
"slr": 1,
118+
"year": 10,
119+
"model": 1,
120+
"ssp": 1,
121+
}
122+
105123
ce_batch_dims = [i for i in gdppc.dims] + [
106124
i for i in ds.dims if i not in gdppc.dims and i != "batch"
107125
]
@@ -110,15 +128,14 @@ def reduce_damages(
110128
i for i in gdppc.region.values if i in ce_batch_coords["region"]
111129
]
112130
ce_shapes = [len(ce_batch_coords[c]) for c in ce_batch_dims]
113-
ce_chunks = [xr.open_zarr(damages).chunks[c][0] for c in ce_batch_dims]
114131

115132
template = xr.DataArray(
116-
da.empty(ce_shapes, chunks=ce_chunks),
133+
da.empty(ce_shapes),
117134
dims=ce_batch_dims,
118135
coords=ce_batch_coords,
119-
)
136+
).chunk(chunkies)
120137

121-
other = xr.open_zarr(damages)
138+
other = xr.open_zarr(damages).chunk(chunkies)
122139

123140
out = other.map_blocks(
124141
ce_from_chunk,
@@ -205,7 +222,21 @@ def sum_AMEL(
205222
for sector in sectors:
206223
print(f"Opening {sector},{params[sector]['sector_path']}")
207224
ds = xr.open_zarr(params[sector]["sector_path"], consolidated=True)
208-
ds = ds[params[sector][var]].rename(var)
225+
ds = (
226+
ds[params[sector][var]]
227+
.rename(var)
228+
.chunk(
229+
{
230+
"batch": 15,
231+
"ssp": 1,
232+
"model": 1,
233+
"rcp": 1,
234+
"gcm": 1,
235+
"year": 10,
236+
"region": -1,
237+
}
238+
)
239+
)
209240
ds = xr.where(np.isinf(ds), np.nan, ds)
210241
datasets.append(ds)
211242

tests/test_input_damages.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dscim.menu.simple_storage import EconVars
1010
from dscim.preprocessing.input_damages import (
1111
_parse_projection_filesys,
12+
concatenate_damage_output,
1213
calculate_labor_impacts,
1314
concatenate_labor_damages,
1415
calculate_labor_batch_damages,
@@ -31,7 +32,7 @@ def test_parse_projection_filesys(tmp_path):
3132
"""
3233
Test that parse_projection_filesys correctly retrieves projection system output structure
3334
"""
34-
rcp = ["rcp85", "rcp45"]
35+
rcp = ["rcp45", "rcp85"]
3536
gcm = ["ACCESS1-0", "GFDL-CM3"]
3637
model = ["high", "low"]
3738
ssp = [f"SSP{n}" for n in range(2, 4)]
@@ -45,14 +46,14 @@ def test_parse_projection_filesys(tmp_path):
4546
os.makedirs(os.path.join(tmp_path, b, r, g, m, s))
4647

4748
out_expected = {
48-
"batch": list(chain(repeat("batch9", 16), repeat("batch6", 16))),
49-
"rcp": list(chain(repeat("rcp85", 8), repeat("rcp45", 8))) * 2,
49+
"batch": list(chain(repeat("batch6", 16), repeat("batch9", 16))),
50+
"rcp": list(chain(repeat("rcp45", 8), repeat("rcp85", 8))) * 2,
5051
"gcm": list(chain(repeat("ACCESS1-0", 4), repeat("GFDL-CM3", 4))) * 4,
5152
"model": list(chain(repeat("high", 2), repeat("low", 2))) * 8,
5253
"ssp": ["SSP2", "SSP3"] * 16,
5354
"path": [
5455
os.path.join(tmp_path, b, r, g, m, s)
55-
for b in ["batch9", "batch6"]
56+
for b in ["batch6", "batch9"]
5657
for r in rcp
5758
for g in gcm
5859
for m in model
@@ -64,11 +65,83 @@ def test_parse_projection_filesys(tmp_path):
6465
df_out_expected = pd.DataFrame(out_expected)
6566

6667
df_out_actual = _parse_projection_filesys(input_path=tmp_path)
68+
df_out_actual = df_out_actual.sort_values(
69+
by=["batch", "rcp", "gcm", "model", "ssp"]
70+
)
6771
df_out_actual.reset_index(drop=True, inplace=True)
6872

6973
pd.testing.assert_frame_equal(df_out_expected, df_out_actual)
7074

7175

76+
def test_concatenate_damage_output(tmp_path):
77+
"""
78+
Test that concatenate_damage_output correctly concatenates damages across batches and saves to a single zarr file
79+
"""
80+
d = os.path.join(tmp_path, "concatenate_in")
81+
if not os.path.exists(d):
82+
os.makedirs(d)
83+
84+
for b in ["batch" + str(i) for i in range(0, 15)]:
85+
ds_in = xr.Dataset(
86+
{
87+
"delta_rebased": (
88+
["ssp", "rcp", "model", "gcm", "batch", "year", "region"],
89+
np.full((2, 2, 2, 2, 1, 2, 2), 1).astype(object),
90+
),
91+
"histclim_rebased": (
92+
["ssp", "rcp", "model", "gcm", "batch", "year", "region"],
93+
np.full((2, 2, 2, 2, 1, 2, 2), 2),
94+
),
95+
},
96+
coords={
97+
"batch": (["batch"], [b]),
98+
"gcm": (["gcm"], np.array(["ACCESS1-0", "BNU-ESM"], dtype=object)),
99+
"model": (["model"], ["IIASA GDP", "OECD Env-Growth"]),
100+
"rcp": (["rcp"], ["rcp45", "rcp85"]),
101+
"region": (["region"], ["ZWE.test_region", "USA.test_region"]),
102+
"ssp": (["ssp"], ["SSP2", "SSP3"]),
103+
"year": (["year"], [2020, 2099]),
104+
},
105+
)
106+
107+
infile = os.path.join(d, f"test_insuffix_{b}.zarr")
108+
109+
ds_in.to_zarr(infile)
110+
111+
ds_out_expected = xr.Dataset(
112+
{
113+
"delta_rebased": (
114+
["ssp", "rcp", "model", "gcm", "batch", "year", "region"],
115+
np.full((2, 2, 2, 2, 15, 2, 2), 1),
116+
),
117+
"histclim_rebased": (
118+
["ssp", "rcp", "model", "gcm", "batch", "year", "region"],
119+
np.full((2, 2, 2, 2, 15, 2, 2), 2),
120+
),
121+
},
122+
coords={
123+
"batch": (["batch"], ["batch" + str(i) for i in range(0, 15)]),
124+
"gcm": (["gcm"], ["ACCESS1-0", "BNU-ESM"]),
125+
"model": (["model"], ["IIASA GDP", "OECD Env-Growth"]),
126+
"rcp": (["rcp"], ["rcp45", "rcp85"]),
127+
"region": (["region"], ["ZWE.test_region", "USA.test_region"]),
128+
"ssp": (["ssp"], ["SSP2", "SSP3"]),
129+
"year": (["year"], [2020, 2099]),
130+
},
131+
)
132+
133+
concatenate_damage_output(
134+
damage_dir=d,
135+
basename="test_insuffix",
136+
save_path=os.path.join(d, "concatenate.zarr"),
137+
)
138+
ds_out_actual = xr.open_zarr(os.path.join(d, "concatenate.zarr")).sel(
139+
batch=["batch" + str(i) for i in range(0, 15)]
140+
)
141+
142+
xr.testing.assert_equal(ds_out_expected, ds_out_actual)
143+
144+
72145
@pytest.fixture
73146
def labor_in_val_fixture(tmp_path):
74147
"""
@@ -697,7 +770,9 @@ def energy_in_netcdf_fixture(tmp_path):
697770
"region",
698771
"year",
699772
],
700-
np.full((1, 1, 1, 1, 1, 2, 2), 2),
773+
np.full((1, 1, 1, 1, 1, 2, 2), 2).astype(
774+
object
775+
),
701776
),
702777
},
703778
coords={
@@ -1030,11 +1105,11 @@ def test_prep_mortality_damages(
10301105
{
10311106
"delta": (
10321107
["gcm", "batch", "ssp", "rcp", "model", "year", "region"],
1033-
np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089),
1108+
np.float32(np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089)),
10341109
),
10351110
"histclim": (
10361111
["gcm", "batch", "ssp", "rcp", "model", "year", "region"],
1037-
np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089),
1112+
np.float32(np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089)),
10381113
),
10391114
},
10401115
coords={

0 commit comments

Comments
 (0)