@@ -137,8 +137,7 @@ def update(
137137 self .inactive_shapes = tuple (output .shape for output in outputs )
138138 else :
139139 self .taylor_factors = {}
140- for i , output in enumerate (outputs ):
141- features = output .to (self .taylor_factors_dtype )
140+ for i , features in enumerate (outputs ):
142141 new_factors : Dict [int , torch .Tensor ] = {0 : features }
143142 is_first_update = self .last_update_step is None
144143 if not is_first_update :
@@ -152,8 +151,8 @@ def update(
152151 prev = prev_factors .get (j )
153152 if prev is None :
154153 break
155- new_factors [j + 1 ] = (new_factors [j ] - prev .to (self . taylor_factors_dtype )) / delta_step
156- self .taylor_factors [i ] = new_factors
154+ new_factors [j + 1 ] = (new_factors [j ] - prev .to (features . dtype )) / delta_step
155+ self .taylor_factors [i ] = { order : factor . to ( self . taylor_factors_dtype ) for order , factor in new_factors . items ()}
157156
158157 self .last_update_step = self .current_step
159158
@@ -179,14 +178,15 @@ def predict(self) -> List[torch.Tensor]:
179178 if not self .taylor_factors :
180179 raise ValueError ("Taylor factors empty during prediction." )
181180 for i in range (len (self .module_dtypes )):
181+ output_dtype = self .module_dtypes [i ]
182182 taylor_factors = self .taylor_factors [i ]
183183 # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
184- output = torch .zeros_like (taylor_factors [0 ])
184+ output = torch .zeros_like (taylor_factors [0 ], dtype = output_dtype )
185185 for order , factor in taylor_factors .items ():
186186 # Note: order starts at 0
187187 coeff = (step_offset ** order ) / math .factorial (order )
188- output = output + factor * coeff
189- outputs .append (output . to ( self . module_dtypes [ i ]) )
188+ output = output + factor . to ( output_dtype ) * coeff
189+ outputs .append (output )
190190
191191 return outputs
192192
0 commit comments