Skip to content

Commit 9633375

Browse files
Fixing hour argument across predictions
1 parent 6341c15 commit 9633375

File tree

2 files changed

+58
-37
lines changed

2 files changed

+58
-37
lines changed

gencast_fp/prediction/predict_gencast.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# gencast_fp/predict/predict.py
22
import os
3-
import dataclasses
43
import logging
4+
import dataclasses
55
import numpy as np
6+
import pandas as pd
67
import xarray
78
import haiku as hk
89
import jax
@@ -278,16 +279,26 @@ def run_predict_multiday(
278279
"""Predict multiple days' worth of rollouts.
279280
Calls run_predict for each day in the start_date and end_date range."""
280281
#
281-
start_date = np.datetime64(start_date)
282-
end_date = np.datetime64(end_date)
283-
date_range = np.arange(
284-
start_date, end_date + np.timedelta64(1, "D"), dtype="datetime64[D]"
285-
)
282+
# start_date = np.datetime64(start_date)
283+
# end_date = np.datetime64(end_date)
284+
# date_range = np.arange(
285+
# start_date, end_date + np.timedelta64(1, "D"), dtype="datetime64[D]"
286+
# )
287+
fmt = "%Y-%m-%d:%H"
288+
289+
# Parse exact hour from input
290+
start_ts = pd.to_datetime(start_date, format=fmt)
291+
end_ts = pd.to_datetime(end_date, format=fmt)
292+
293+
# Generate a date range in 12-hour increments
294+
date_range = pd.date_range(start=start_ts, end=end_ts, freq="12H")
286295

287296
for current_date in date_range:
288297

298+
289299
logging.info("======================================================")
290300
logging.info(f"Running prediction on date: {current_date}")
301+
"""
291302
out_fn = run_predict(
292303
current_date,
293304
input_dir,
@@ -300,6 +311,7 @@ def run_predict_multiday(
300311
)
301312
logging.info(f"Prediction saved to file: {out_fn}")
302313
logging.info("======================================================")
314+
"""
303315
return out_dir
304316

305317

gencast_fp/preprocess/fp_to_era5.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -80,61 +80,70 @@ def _fixAttrs(v, ds, ref_attrs=e5_attrs):
8080
if len(ds[v].shape) > 2:
8181
ds[v].attrs[a] = ref_attrs[v][a]
8282

83-
def _fixPoles_scalar(a5,ap):
83+
84+
def _fixPoles_scalar(a5, ap):
8485
"""
85-
Fix scalar variable at poles. Assumes first and last latitudinal points are +/- 90
86+
Fix scalar variable at poles. Assumes first and last
87+
latitudinal points are +/- 90
8688
in both source (ap) and target (a5). Notice that ERA-5 has latitudinal
8789
grid north-to-south while FP is south-to-north.
88-
90+
8991
Indexing is assumed to be one of the following:
9092
- a(lev,lat,lon)
9193
- a(lat,lon)
92-
94+
9395
"""
9496
shape = ap.shape
9597
if len(shape) == 3: # 3D
9698
for k in range(shape[0]):
97-
a5[k, 0,:] = ap[k,-1,:].mean() # north pole
98-
a5[k,-1,:] = ap[k, 0,:].mean() # south pole
99+
a5[k, 0, :] = ap[k,-1, :].mean() # north pole
100+
a5[k, -1, :] = ap[k, 0, :].mean() # south pole
99101
elif len(shape) == 2: # 2D
100-
a5[0,:] = ap[-1,:].mean(axis=0) # north pole
101-
a5[-1,:] = ap[ 0,:].mean(axis=0) # south pole
102+
a5[0, :] = ap[-1, :].mean(axis=0) # north pole
103+
a5[-1, :] = ap[0, :].mean(axis=0) # south pole
102104
else:
103105
print(ap.shape)
104106
raise ValueError('Invalid shape of input variable')
105-
107+
108+
106109
def _fixPoles_vector(u5, v5, up, vp):
107110
"""
108-
Fix vector variables at poles. Assumes first and last latitudinal points are +/- 90
111+
Fix vector variables at poles. Assumes first and
112+
last latitudinal points are +/- 90
109113
in both source (ap) and target (a5). Notice that ERA-5 has latitudinal grid
110114
north-to-south while FP is south-to-north.
111-
115+
112116
This is implemented with a simple linear interpolation in longitude.
113-
117+
114118
"""
115119

116120
lon_p = up.lon
117-
lon_5 = u5.longitude
121+
lon_5 = u5.longitude
118122

119123
shape = up.shape
120-
124+
121125
if len(shape) == 3: # 3D
122126
for k in range(shape[0]):
123-
u5[k, 0].data[:] = np.interp(lon_5, lon_p, up[k,-1].data[:], period=360.)
124-
v5[k, 0].data[:] = np.interp(lon_5, lon_p, vp[k,-1].data[:], period=360.)
125-
u5[k,-1].data[:] = np.interp(lon_5, lon_p, up[k, 0].data[:], period=360.)
126-
v5[k,-1].data[:] = np.interp(lon_5, lon_p, vp[k, 0].data[:], period=360.)
127-
127+
u5[k, 0].data[:] = np.interp(
128+
lon_5, lon_p, up[k,-1].data[:], period=360.)
129+
v5[k, 0].data[:] = np.interp(
130+
lon_5, lon_p, vp[k,-1].data[:], period=360.)
131+
u5[k, -1].data[:] = np.interp(
132+
lon_5, lon_p, up[k, 0].data[:], period=360.)
133+
v5[k, -1].data[:] = np.interp(
134+
lon_5, lon_p, vp[k, 0].data[:], period=360.)
135+
128136
elif len(shape) == 2: # 2D
129-
u5[ 0].data[:] = np.interp(lon_5, lon_p, up[-1].data[:], period=360.)
130-
v5[ 0].data[:] = np.interp(lon_5, lon_p, vp[-1].data[:], period=360.)
131-
u5[-1].data[:] = np.interp(lon_5, lon_p, up[ 0].data[:], period=360.)
132-
v5[-1].data[:] = np.interp(lon_5, lon_p, vp[ 0].data[:], period=360.)
133-
137+
u5[0].data[:] = np.interp(lon_5, lon_p, up[-1].data[:], period=360.)
138+
v5[0].data[:] = np.interp(lon_5, lon_p, vp[-1].data[:], period=360.)
139+
u5[-1].data[:] = np.interp(lon_5, lon_p, up[0].data[:], period=360.)
140+
v5[-1].data[:] = np.interp(lon_5, lon_p, vp[0].data[:], period=360.)
141+
134142
else:
135143
print(up.shape)
136144
raise ValueError('Invalid shape of input variable')
137145

146+
138147
def _scalar_vectors(ds):
139148
"""
140149
Given a dataset, return list of scalar and vector variables.
@@ -144,7 +153,7 @@ def _scalar_vectors(ds):
144153
# -------------------------
145154
U, V, S = {}, {}, []
146155
for v in ds.data_vars:
147-
if len(ds[v].shape) < 3:
156+
if len(ds[v].shape) < 3:
148157
continue
149158
std_name = ds[v].attrs['standard_name']
150159
if 'eastward' in std_name:
@@ -158,15 +167,14 @@ def _scalar_vectors(ds):
158167
# ------------
159168
VP = []
160169
for u_ in U:
161-
v_ = u_.replace('eastward','northward')
162-
p = (U[u_],V[v_])
170+
v_ = u_.replace('eastward', 'northward')
171+
p = (U[u_], V[v_])
163172
VP += [p,]
164-
173+
165174
return (S, VP)
166-
167-
#--
168175

169-
def _gat2s ( template, time, expid='f5295' ):
176+
177+
def _gat2s(template, time, expid='f5295'):
170178
"""
171179
Expand GrADS style templates/
172180
"""
@@ -181,6 +189,7 @@ def _gat2s ( template, time, expid='f5295' ):
181189
replace('%d2',d2).replace('%h2',h2).\
182190
replace('%n2',n2).replace('%expid',expid)
183191

192+
184193
def discover_files(time, outdir='./', expid='f5295'):
185194
"""
186195
Return dictionary with FP file names on discover, given time.

0 commit comments

Comments
 (0)