@@ -513,32 +513,27 @@ def _recompute_all_losses(self):
513
513
@property
514
514
def _scale (self ):
515
515
# get the output scale
516
- scale = self ._max_value - self ._min_value
517
- if isinstance (scale , np .ndarray ):
518
- scale [scale == 0 ] = 1
519
- elif scale == 0 :
520
- scale = 1
521
- return scale
516
+ return np .max (self ._max_value - self ._min_value )
522
517
523
518
def _update_range (self , new_output ):
524
519
if self ._min_value is None or self ._max_value is None :
525
520
# this is the first point, nothing to do, just set the range
526
521
self ._min_value = np .array (new_output )
527
522
self ._max_value = np .array (new_output )
528
- self ._old_scale = self ._scale
523
+ self ._old_scale = self ._scale or 1
529
524
return False
530
525
531
526
# if range in one or more directions is doubled, then update all losses
532
527
self ._min_value = np .minimum (self ._min_value , new_output )
533
528
self ._max_value = np .maximum (self ._max_value , new_output )
534
529
535
- scale_multiplier = 1 / self ._scale
530
+ scale_multiplier = 1 / ( self ._scale or 1 )
536
531
if isinstance (scale_multiplier , float ):
537
532
scale_multiplier = np .array ([scale_multiplier ], dtype = float )
538
533
539
534
# the maximum absolute value that is in the range. Because this is the
540
535
# largest number, this also has the largest absolute numerical error.
541
- max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]), axis = 0 )
536
+ max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]))
542
537
# since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
543
538
abs_err = 1e-15 * max_absolute_value_in_range
544
539
# when scaling the floats, the error gets increased.
0 commit comments