Skip to content

Commit e5f494a

Browse files
Fixed the Snakemake/Hydra behavior. Fixes #47
1 parent 54092cc commit e5f494a

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

Snakefile

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
1-
import yaml
1+
import hydra
2+
from omegaconf import OmegaConf
3+
import json
24

35
conda: "requirements.yaml"
4-
configfile: "conf/config.yaml"
56

6-
# == Load configuration ==
7+
# the workflow configuration file is orchestrated by hydra
8+
# read config with hydra
9+
with hydra.initialize(config_path="conf", version_base=None):
10+
cfg = hydra.compose(config_name="config", overrides=[])
11+
#print(OmegaConf.to_yaml(cfg))
712

8-
# dynamic config files
9-
defaults_dict = {key: value for d in config['defaults'] if isinstance(d, dict) for key, value in d.items()}
10-
shapefiles_cfg = yaml.safe_load(open(f"conf/shapefiles/{defaults_dict['shapefiles']}.yaml", 'r'))
11-
# == Define variables ==
12-
shapefile_list = shapefiles_cfg.keys()
13-
print(shapefile_list)
13+
# convert to dict of single shapefile dicts
14+
shapefiles_cfg = OmegaConf.to_container(cfg.shapefiles, resolve=True)
15+
#print(shapefiles_cfg)
16+
shapefiles_cfg_dict = {shapefile["name"]: "[" + json.dumps(shapefile).replace('"', '') + "]" for shapefile in shapefiles_cfg}
17+
#print(shapefiles_cfg_dict)
18+
shapefiles_list = list(shapefiles_cfg_dict.keys())
19+
#print(shapefiles_list)
20+
# print(f"""
21+
# python src/aggregate_climate_types.py "+shapefiles={shapefiles_cfg_dict[shapefiles_list[0]]}"
22+
# """)
23+
24+
#raise ValueError("stop here")
1425

1526
rule all:
1627
input:
1728
expand(f"data/output/climate_types_raster2polygon/climate_types_{{shapefile_name}}.parquet",
18-
shapefile_name=shapefile_list
29+
shapefile_name=shapefiles_list
1930
)
2031

2132
rule download_climate_types:
2233
output:
23-
f"data/input/climate_types/{config['climate_types_file']}"
34+
f"data/input/climate_types/{cfg.climate_types_file}"
2435
shell:
2536
"python src/download_climate_types.py"
2637

@@ -33,11 +44,17 @@ rule download_climate_types:
3344

3445
rule aggregate_climate_types:
3546
input:
36-
f"data/input/climate_types/{config['climate_types_file']}",
47+
f"data/input/climate_types/{cfg.climate_types_file}",
3748
f"data/input/shapefiles/{{shapefile_name}}/{{shapefile_name}}.shp"
3849
output:
3950
f"data/output/climate_types_raster2polygon/climate_types_{{shapefile_name}}.parquet",
4051
f"data/intermediate/climate_pcts/climate_pcts_{{shapefile_name}}.json",
4152
f"data/intermediate/climate_pcts/climate_types_{{shapefile_name}}.csv"
53+
params:
54+
shapefile_name = lambda wildcards: shapefiles_cfg_dict[wildcards.shapefile_name]
4255
shell:
43-
f"python src/aggregate_climate_types.py"
56+
(f"""
57+
echo {{wildcards.shapefile_name}}
58+
python src/aggregate_climate_types.py "+shapefiles={{params.shapefile_name}}"
59+
""")
60+
#python src/aggregate_climate_types.py "+shapefiles=[{name: CAN_ADM2, url: null, idvar: shapeID, output_idvar: id}]"

src/aggregate_climate_types.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
@hydra.main(config_path="../conf", config_name="config", version_base=None)
1515
def main(cfg):
16+
print(cfg.shapefiles)
1617
LOGGER.info("""
1718
# Extract transform, crs, nodata from raster
1819
""")
@@ -39,10 +40,10 @@ def main(cfg):
3940
)
4041

4142
# read shapefile
42-
for shapefile_name in cfg.shapefiles:
43-
LOGGER.info(f"Shapefile: {shapefile_name}")
44-
idvar = cfg.shapefiles[shapefile_name].idvar
45-
shp_path = f"data/input/shapefiles/{shapefile_name}/{shapefile_name}.shp"
43+
for shapefile in cfg.shapefiles:
44+
LOGGER.info(f"Shapefile: {shapefile.name}")
45+
idvar = shapefile.idvar
46+
shp_path = f"data/input/shapefiles/{shapefile.name}/{shapefile.name}.shp"
4647
LOGGER.info(f"Reading shapefile {shp_path}")
4748
shp = gpd.read_file(shp_path)
4849
LOGGER.info(f"Read shapefile with head\n: {shp.drop(columns='geometry').head()}")
@@ -93,7 +94,7 @@ def main(cfg):
9394
LOGGER.info(f"Fraction of locations with ties: {100 * frac_ties:.2f}%")
9495

9596
intermediate_dir = f"data/intermediate/climate_pcts"
96-
pcts_file = f"{intermediate_dir}/climate_pcts_{shapefile_name}.json"
97+
pcts_file = f"{intermediate_dir}/climate_pcts_{shapefile.name}.json"
9798
LOGGER.info(f"Saving pcts to {pcts_file}")
9899
with open(pcts_file, "w") as f:
99100
json.dump(avs, f)
@@ -109,7 +110,7 @@ def main(cfg):
109110
class_df["climate_type_long"] = class_df["climate_type_num"].map(codedict_long) # if a polygon intersects only with water then there is no assignment
110111
class_df = class_df.drop(columns="climate_type_num")
111112

112-
class_file = f"{intermediate_dir}/climate_types_{shapefile_name}.csv"
113+
class_file = f"{intermediate_dir}/climate_types_{shapefile.name}.csv"
113114
LOGGER.info(f"Saving classification to {class_file}")
114115
class_df.to_csv(class_file, index=False)
115116

@@ -125,9 +126,9 @@ def main(cfg):
125126

126127
output_df = pd.merge(class_df, output_df, on="id")
127128

128-
output_file = f"data/output/climate_types_raster2polygon/climate_types_{shapefile_name}.parquet"
129+
output_file = f"data/output/climate_types_raster2polygon/climate_types_{shapefile.name}.parquet"
129130
LOGGER.info(f"Saving output to {output_file}")
130-
output_df.rename(columns={"id": cfg.shapefiles[shapefile_name].output_idvar}, inplace=True)
131+
output_df.rename(columns={"id": shapefile.output_idvar}, inplace=True)
131132
output_df.to_parquet(output_file)
132133

133134
if __name__ == "__main__":

0 commit comments

Comments (0)