Skip to content

Commit bf87875

Browse files
committed
Bug fixes and updates
1 parent cc0f111 commit bf87875

File tree

15 files changed

+30
-30
lines changed

15 files changed

+30
-30
lines changed

graph_weather/data/IFSAnalysis_dataloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, filepath: str, features: list, start_year: int = 2016, end_ye
4343
"""
4444

4545
super().__init__()
46-
assert (
47-
start_year <= end_year
48-
), f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
46+
assert start_year <= end_year, (
47+
f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
48+
)
4949
assert start_year >= 2016 and start_year <= 2022, "Time data range from 2016 to 2022"
5050
assert end_year >= 2016 and end_year <= 2022, "Time data range from 2016 to 2022"
5151
self.data = xr.open_zarr(filepath)

graph_weather/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Dataloaders and data processing utilities"""
22

3-
from .anemoi_dataloder import AnemoiDataset
3+
from .anemoi_dataloader import AnemoiDataset
44
from .nnja_ai import SensorDataset, collate_fn
55
from .weather_station_reader import WeatherStationReader

graph_weather/data/weather_station_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ def _convert_to_synopticpy(self, observations: xr.Dataset) -> Optional[Dict]:
330330
values = station_data[var_name].values
331331

332332
# Add to observations
333-
synoptic_data["STATION"][str(station)]["OBSERVATIONS"][
334-
var_name
335-
] = values.tolist()
333+
synoptic_data["STATION"][str(station)]["OBSERVATIONS"][var_name] = (
334+
values.tolist()
335+
)
336336

337337
return synoptic_data
338338

graph_weather/models/aurora/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def forward(self, points: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
4343

4444
# Normalize coordinates to [-1, 1] range
4545
normalized_points = torch.stack(
46-
[points[..., 0] / 180.0, points[..., 1] / 90.0], dim=-1 # longitude # latitude
46+
[points[..., 0] / 180.0, points[..., 1] / 90.0],
47+
dim=-1, # longitude # latitude
4748
)
4849

4950
# Separately encode coordinates and features

graph_weather/models/fengwu_ghr/layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ def __init__(
105105
)
106106
)
107107
if self.res:
108-
assert (
109-
image_size is not None and scale_factor is not None
110-
), "If res=True, you must provide h, w and scale_factor"
108+
assert image_size is not None and scale_factor is not None, (
109+
"If res=True, you must provide h, w and scale_factor"
110+
)
111111
h, w = pair(image_size)
112112
s_h, s_w = pair(scale_factor)
113113
self.res_layers.append(

graph_weather/models/gencast/graph/graph_builder.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,7 @@ def _init_khop_mesh_graph(self):
309309
edge_index,
310310
values=torch.ones_like(edge_index[0], dtype=torch.float32),
311311
size=(self._num_mesh_nodes, self._num_mesh_nodes),
312-
).to(
313-
self.khop_device
314-
) # cpu is more memory-efficient, why?
312+
).to(self.khop_device) # cpu is more memory-efficient, why?
315313

316314
adj_k = adj.coalesce()
317315
for _ in range(self.num_hops - 1):

graph_weather/models/weathermesh/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
self.transformer_layers = nn.ModuleList(
5151
[
5252
NeighborhoodAttention3D(
53-
dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
53+
embed_dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
5454
)
5555
for _ in range(num_transformer_layers)
5656
]

graph_weather/models/weathermesh/encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(
7676
self.transformer_layers = nn.ModuleList(
7777
[
7878
NeighborhoodAttention3D(
79-
dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
79+
embed_dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
8080
)
8181
for _ in range(num_transformer_layers)
8282
]

graph_weather/models/weathermesh/processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, latent_dim, n_layers=10, kernel=(5, 7, 7), num_heads=8):
3131
self.layers = nn.ModuleList(
3232
[
3333
NeighborhoodAttention3D(
34-
dim=latent_dim,
34+
embed_dim=latent_dim,
3535
num_heads=num_heads,
3636
kernel_size=kernel,
3737
)

0 commit comments

Comments
 (0)