Skip to content

Commit dde3cb6

Browse files
authored
fix(training): provide more informative error when user specifies inexistent node attribute (#663)
Small fix to make the error message more informative. Instead of just telling a user that the specified attribute does not exist in the graph data, it also provides a list of the ones that are available. So for instance, if one mistakenly specified "cutout_mask" instead of "cutout", they will know right away what the problem is.
1 parent 04c5321 commit dde3cb6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

training/src/anemoi/training/losses/scalers/node_attributes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ def __init__(
9494
norm=norm,
9595
**kwargs,
9696
)
97-
if self.scaling_mask_attribute_name not in self.nodes:
98-
error_msg = f"scaling_mask_attribute_name {self.scaling_mask_attribute_name} not found in graph_object"
97+
if self.scaling_mask_attribute_name not in self.nodes.node_attrs():
98+
error_msg = f"{self.__class__.__module__}.{self.__class__.__name__}: "
99+
error_msg += f"scaling_mask_attribute_name '{self.scaling_mask_attribute_name}' not found in graph_data - "
100+
avail_masks = [k for k, v in self.nodes.items() if getattr(v, "dtype", None) == torch.bool]
101+
error_msg += f"available boolean node attributes are: {avail_masks}"
99102
raise KeyError(error_msg)
100103

101104
def reweight_attribute_values(self, values: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)