@@ -92,16 +92,6 @@ def __post_init__(self):
9292 raise NotImplementedError (f"Unsupported mode { self .mode } for { self .__class__ .__name__ } " )
9393
9494
95- #
96- # helpers
97- #
98- def ensure_dtype (tensor : xr .DataArray , * , dtype ) -> xr .DataArray :
99- """
100- Convert array to a given datatype
101- """
102- return tensor .astype (dtype )
103-
104-
10595#
10696# Pre- and Postprocessing implementations
10797#
@@ -129,12 +119,12 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
129119
130120@dataclass
131121class Binarize (Processing ):
132- """'output = tensor > threshold' (note: returns float array) ."""
122+ """'output = tensor > threshold'."""
133123
134124 threshold : float = MISSING # make dataclass inheritance work for py<3.10 by using an explicit MISSING value.
135125
136126 def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
137- return ensure_dtype ( tensor > self .threshold , dtype = "float32" )
127+ return tensor > self .threshold
138128
139129
140130@dataclass
@@ -145,7 +135,7 @@ class Clip(Processing):
145135 max : float = MISSING
146136
147137 def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
148- return ensure_dtype ( tensor .clip (min = self .min , max = self .max ), dtype = "float32" )
138+ return tensor .clip (min = self .min , max = self .max )
149139
150140
151141@dataclass
@@ -155,7 +145,7 @@ class EnsureDtype(Processing):
155145 dtype : str = MISSING
156146
157147 def apply (self , tensor : xr .DataArray ) -> xr .DataArray :
158- return ensure_dtype ( tensor , dtype = self .dtype )
148+ return tensor . astype ( self .dtype )
159149
160150
161151@dataclass
@@ -175,7 +165,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
175165 gain = self .gain
176166 offset = self .offset
177167
178- return ensure_dtype ( tensor * gain + offset , dtype = "float32" )
168+ return tensor * gain + offset
179169
180170 def __post_init__ (self ):
181171 super ().__post_init__ ()
@@ -210,8 +200,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
210200 ref_mean = self .get_computed_measure (self .reference_tensor , Mean (axes ), mode = self .mode )
211201 ref_std = self .get_computed_measure (self .reference_tensor , Std (axes ), mode = self .mode )
212202
213- tensor = (tensor - mean ) / (std + self .eps ) * (ref_std + self .eps ) + ref_mean
214- return ensure_dtype (tensor , dtype = "float32" )
203+ return (tensor - mean ) / (std + self .eps ) * (ref_std + self .eps ) + ref_mean
215204
216205
217206@dataclass
@@ -236,7 +225,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
236225 v_lower = self .get_computed_measure (ref_name , Percentile (self .min_percentile , axes = axes ))
237226 v_upper = self .get_computed_measure (ref_name , Percentile (self .max_percentile , axes = axes ))
238227
239- return ensure_dtype (( tensor - v_lower ) / (v_upper - v_lower + self .eps ), dtype = "float32" )
228+ return ( tensor - v_lower ) / (v_upper - v_lower + self .eps )
240229
241230 def __post_init__ (self ):
242231 super ().__post_init__ ()
@@ -281,8 +270,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
281270 else :
282271 raise ValueError (self .mode )
283272
284- tensor = (tensor - mean ) / (std + self .eps )
285- return ensure_dtype (tensor , dtype = "float32" )
273+ return (tensor - mean ) / (std + self .eps )
286274
287275
288276_KnownProcessing = TypedDict (
0 commit comments