Skip to content

Commit fee2f0c

Browse files
authored
Modify DataReaderObs to get base_yyyy... from stream config (ecmwf#794)
* modify DataReaderObs to get base_yyyy... from stream config, and set it in the ctor, with default of 19700101. Use it in _setup_sample_index. Remove loading obs_id attr. Add igra.yml with example usage. * add license to igra config * update to ISO base_datetime, parse to read idx from zarr
1 parent ac734f9 commit fee2f0c

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

config/streams/igra/igra.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# (C) Copyright 2025 WeatherGenerator contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
IGRA :
11+
type : obs
12+
filenames : ['igra.zarr']
13+
base_datetime : '1750-01-01T00:00:00'
14+
loss_weight : 1.0
15+
#masking_rate_none : 0.05
16+
token_size : 64
17+
tokenize_spacetime : True
18+
max_num_targets: -1
19+
embed :
20+
net : transformer
21+
num_tokens : 1
22+
num_heads : 2
23+
dim_embed : 256
24+
num_blocks : 2
25+
embed_target_coords :
26+
net : linear
27+
dim_embed : 256
28+
target_readout :
29+
type : 'obs_value' # token or obs_value
30+
num_layers : 2
31+
num_heads : 4
32+
pred_head :
33+
ens_size : 1
34+
num_layers : 1

src/weathergen/datasets/data_reader_obs.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def __init__(self, tw_handler: TimeWindowHandler, filename: Path, stream_info: d
3333
self.z = zarr.open(filename, mode="r")
3434
self.data = self.z["data"]
3535
self.dt = self.z["dates"] # datetime only
36-
self.hrly_index = self.z["idx_197001010000_1"]
36+
self.base_datetime = stream_info.get("base_datetime", "1970-01-01T00:00:00")
37+
format_str = "%Y-%m-%dT%H:%M:%S"
38+
self.base_datetime = datetime.datetime.strptime(str(self.base_datetime), format_str)
39+
# To read idx convert to a string, format e.g.: 197001010000
40+
base_date_str = self.base_datetime.strftime("%Y%m%d%H%M")
41+
self.hrly_index = self.z[f"idx_{base_date_str}_1"]
3742
self.colnames = self.data.attrs["colnames"]
3843

3944
data_colnames = [col for col in self.colnames if "obsvalue" in col]
@@ -63,7 +68,7 @@ def __init__(self, tw_handler: TimeWindowHandler, filename: Path, stream_info: d
6368
self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0]))
6469
self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx]
6570

66-
# load additional properties (mean, var, obs_id)
71+
# load additional properties (mean, var)
6772
self._load_properties()
6873
self.mean = np.array(self.properties["means"]) # [data_idx]
6974
self.stdev = np.sqrt(np.array(self.properties["vars"])) # [data_idx])
@@ -140,19 +145,13 @@ def _setup_sample_index(self) -> None:
140145
)
141146
step_hrs = int(self.time_window_handler.t_window_step.item().total_seconds()) // 3600
142147

143-
# TODO: move to ctor
144-
base_yyyymmddhhmm = 197001010000
145-
146-
# Derive new index based on hourly backbone index
147-
format_str = "%Y%m%d%H%M%S"
148-
base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str)
149148
self.start_dt = self.time_window_handler.t_start.item()
150149
self.end_dt = self.time_window_handler.t_end.item()
151150

152151
## Calculate the number of hours between start of hourly base index
153152
# and the requested sample index
154-
diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() / 3600)
155-
diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() / 3600)
153+
diff_in_hours_start = int((self.start_dt - self.base_datetime).total_seconds() / 3600)
154+
diff_in_hours_end = int((self.end_dt - self.base_datetime).total_seconds() / 3600)
156155

157156
end_range_1 = min(diff_in_hours_end, self.hrly_index.shape[0] - 1)
158157
self.indices_start = self.hrly_index[diff_in_hours_start:end_range_1:step_hrs]
@@ -201,7 +200,6 @@ def _load_properties(self) -> None:
201200

202201
self.properties["means"] = self.data.attrs["means"]
203202
self.properties["vars"] = self.data.attrs["vars"]
204-
self.properties["obs_id"] = self.data.attrs["obs_id"]
205203

206204
@override
207205
def _get(self, idx: int, channels_idx: list[int]) -> ReaderData:

0 commit comments

Comments
 (0)