Skip to content

Commit 01b7034

Browse files
JPXKQX, pre-commit-ci[bot], and anaprietonem
authored
fix(graphs,normalisation): add assert when dividing by 0 (#676)
## Description <!-- What issue or task does this change relate to? --> Raise an error if you try to divide by zero. Example: Using a graph data -> data with `KNNEdges(1)`. ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent fdce0f6 commit 01b7034

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

graphs/src/anemoi/graphs/normalise.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,44 +21,50 @@ class NormaliserMixin:
2121
Supported normalisation methods: None, 'l1', 'l2', 'unit-max', 'unit-range', 'unit-std'.
2222
"""
2323

24-
def compute_nongrouped_statistics(self, values: torch.Tensor, *_args) -> tuple[torch.Tensor, ...]:
24+
def compute_nongrouped_statistics(self, values: torch.Tensor, *_args) -> tuple[float, ...]:
2525
if self.norm == "l1":
26-
return (torch.sum(values),)
26+
statistics = (torch.sum(values),)
2727

2828
elif self.norm == "l2":
29-
return (torch.norm(values),)
29+
statistics = (torch.norm(values),)
3030

3131
elif self.norm == "unit-max":
32-
return (torch.amax(values),)
32+
statistics = (torch.amax(values),)
3333

3434
elif self.norm == "unit-range":
35-
return torch.amin(values), torch.amax(values)
35+
statistics = torch.amin(values), torch.amax(values)
3636

3737
elif self.norm == "unit-std":
3838
std = torch.std(values)
3939
if std == 0:
4040
LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalisation is skipped.")
4141
return (1,)
42-
return (std,)
42+
43+
statistics = (std,)
44+
45+
assert (
46+
statistics[-1] != 0
47+
), f"Normalisation by zero encountered in {self.__class__.__name__} with norm '{self.norm}'."
48+
return statistics
4349

4450
def compute_grouped_statistics(
4551
self, values: torch.Tensor, index: torch.Tensor, num_groups: int, dtype, device
4652
) -> tuple[torch.Tensor, ...]:
4753
if self.norm == "l1":
4854
group_sum = torch.zeros(num_groups, values.shape[1], dtype=dtype, device=device)
4955
group_sum = group_sum.index_add(0, index, values)
50-
return (group_sum[index],)
56+
group_statistics = (group_sum[index],)
5157

5258
elif self.norm == "l2":
5359
group_sq = torch.zeros(num_groups, values.shape[1], dtype=dtype, device=device)
5460
group_sq = group_sq.index_add(0, index, values**2)
5561
group_norm = torch.sqrt(group_sq)
56-
return (group_norm[index],)
62+
group_statistics = (group_norm[index],)
5763

5864
elif self.norm == "unit-max":
5965
group_max = torch.full((num_groups, values.shape[1]), float("-inf"), dtype=dtype, device=device)
6066
group_max = group_max.index_reduce(0, index, values, reduce="amax")
61-
return (group_max[index],)
67+
group_statistics = (group_max[index],)
6268

6369
elif self.norm == "unit-range":
6470
group_min = torch.full((num_groups, values.shape[1]), float("inf"), dtype=dtype, device=device)
@@ -67,7 +73,7 @@ def compute_grouped_statistics(
6773
group_max = group_max.index_reduce(0, index, values, reduce="amax")
6874
denom = group_max - group_min
6975
denom[denom == 0] = 1 # avoid division by zero
70-
return group_min[index], denom[index]
76+
group_statistics = group_min[index], denom[index]
7177

7278
elif self.norm == "unit-std":
7379
# Compute mean
@@ -86,7 +92,12 @@ def compute_grouped_statistics(
8692
group_std = torch.sqrt(group_var)
8793
# Avoid division by zero
8894
group_std[group_std == 0] = 1
89-
return (group_std[index],)
95+
group_statistics = (group_std[index],)
96+
97+
assert torch.all(
98+
group_statistics[-1] != 0
99+
), f"Normalisation by zero encountered in {self.__class__.__name__} with norm '{self.norm}'."
100+
return group_statistics
90101

91102
def normalise(self, values: torch.Tensor, *args) -> torch.Tensor:
92103
"""Normalise the given values.

0 commit comments

Comments
 (0)