@@ -110,11 +110,14 @@ def __init__(
110110 """
111111 if get_state is None :
112112 # Note: Numpy supports copying data between ndarrays with different dtypes.
113- # Hence, our default behavior need not coerce the ndarray represenations of
114- # tensors in `parameters` to float64 when copying over data.
113+ # Hence, our default behavior need not coerce the ndarray representations
114+ # of tensors in `parameters` to float64 when copying over data.
115115 _as_array = as_ndarray if as_array is None else as_array
116116 get_state = partial (
117- get_tensors_as_ndarray_1d , parameters , as_array = _as_array
117+ get_tensors_as_ndarray_1d ,
118+ tensors = parameters ,
119+ dtype = np_float64 ,
120+ as_array = _as_array ,
118121 )
119122
120123 if as_array is None : # per the note, do this after resolving `get_state`
@@ -154,7 +157,7 @@ def __call__(
154157 grads [index : index + size ] = self .as_array (grad .view (- 1 ))
155158 index += size
156159 except RuntimeError as e :
157- value , grads = _handle_numerical_errors (error = e , x = self .state )
160+ value , grads = _handle_numerical_errors (e , x = self .state , dtype = np_float64 )
158161
159162 return value , grads
160163
@@ -174,9 +177,9 @@ def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
174177
175178 size = sum (param .numel () for param in self .parameters .values ())
176179 array = (
177- np_zeros (size )
180+ np_zeros (size , dtype = np_float64 )
178181 if fill_value is None or fill_value == 0.0
179- else np_full (size , fill_value )
182+ else np_full (size , fill_value , dtype = np_float64 )
180183 )
181184 if self .persistent :
182185 self ._gradient_ndarray = array
0 commit comments