@@ -17,20 +17,24 @@ class NanToNum(Transform):
1717 Value to substitute wherever data is NaN.
1818 return_mask : bool, default=False
1919 If True, a mask array will be returned under a new key.
20+ mask_prefix : str, default='mask'
21+ Prefix for the mask key in the output dictionary; the mask is stored under '<mask_prefix>_<key>'.
2022 """
2123
22- def __init__ (self , key : str , default_value : float = 0.0 , return_mask : bool = False ):
24+ def __init__ (self , key : str , default_value : float = 0.0 , return_mask : bool = False , mask_prefix : str = "mask" ):
2325 super ().__init__ ()
2426 self .key = key
2527 self .default_value = default_value
2628 self .return_mask = return_mask
29+ self .mask_prefix = mask_prefix
2730
2831 def get_config (self ) -> dict :
2932 return serialize (
3033 {
3134 "key" : self .key ,
3235 "default_value" : self .default_value ,
3336 "return_mask" : self .return_mask ,
37+ "mask_prefix" : self .mask_prefix ,
3438 }
3539 )
3640
@@ -39,14 +43,20 @@ def mask_key(self) -> str:
3943 """
4044 Key under which the mask will be stored in the output dictionary.
4145 """
42- return f"mask_ { self .key } " if self .key else "mask "
46+ return f"{ self .mask_prefix } _ { self .key } "
4347
4448 def forward (self , data : dict [str , any ], ** kwargs ) -> dict [str , any ]:
4549 """
4650 Forward transform: fill NaNs and optionally output mask under '<mask_prefix>_<key>'.
4751 """
4852 data = data .copy ()
4953
54+ # Check if the mask key already exists in the data
55+ if self .mask_key in data .keys ():
56+ raise ValueError (
57+ f"Mask key '{ self .mask_key } ' already exists in the data. Please choose a different mask_prefix."
58+ )
59+
5060 # Identify NaNs and fill with default value
5161 mask = np .isnan (data [self .key ])
5262 data [self .key ] = np .nan_to_num (data [self .key ], copy = False , nan = self .default_value )
0 commit comments