@@ -483,7 +483,7 @@ def _compute_loss(self, simplex):
483
483
484
484
# scale them to a cube with sides 1
485
485
vertices = vertices @ self ._transform
486
- values = self ._output_multiplier * values
486
+ values = self ._output_multiplier * np . array ( values )
487
487
488
488
# compute the loss on the scaled simplex
489
489
return float (self .loss_per_simplex (vertices , values ))
@@ -513,40 +513,38 @@ def _recompute_all_losses(self):
513
513
@property
514
514
def _scale (self ):
515
515
# get the output scale
516
- return np . max ( self ._max_value - self ._min_value )
516
+ return self ._max_value - self ._min_value
517
517
518
518
def _update_range (self , new_output ):
519
519
if self ._min_value is None or self ._max_value is None :
520
520
# this is the first point, nothing to do, just set the range
521
- self ._min_value = np .array (new_output )
522
- self ._max_value = np .array (new_output )
521
+ self ._min_value = np .min (new_output )
522
+ self ._max_value = np .max (new_output )
523
523
self ._old_scale = self ._scale or 1
524
524
return False
525
525
526
526
# if range in one or more directions is doubled, then update all losses
527
- self ._min_value = np . minimum (self ._min_value , new_output )
528
- self ._max_value = np . maximum (self ._max_value , new_output )
527
+ self ._min_value = min (self ._min_value , np . min ( new_output ) )
528
+ self ._max_value = max (self ._max_value , np . max ( new_output ) )
529
529
530
530
scale_multiplier = 1 / (self ._scale or 1 )
531
- if isinstance (scale_multiplier , float ):
532
- scale_multiplier = np .array ([scale_multiplier ], dtype = float )
533
531
534
532
# the maximum absolute value that is in the range. Because this is the
535
533
# largest number, this also has the largest absolute numerical error.
536
- max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]))
534
+ max_absolute_value_in_range = max (abs (self ._min_value ),
535
+ abs (self ._max_value ))
537
536
# since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
538
537
abs_err = 1e-15 * max_absolute_value_in_range
539
538
# when scaling the floats, the error gets increased.
540
539
scaled_err = abs_err * scale_multiplier
541
540
542
- allowed_numerical_error = 1e-2
543
-
544
541
# do not scale along the axis if the numerical error gets too big
545
- scale_multiplier [scaled_err > allowed_numerical_error ] = 1
542
+ if scaled_err > 1e-2 : # allowed_numerical_error = 1e-2
543
+ scale_multiplier = 1
546
544
547
545
self ._output_multiplier = scale_multiplier
548
546
549
- scale_factor = np . max ( self ._scale / self ._old_scale )
547
+ scale_factor = self ._scale / self ._old_scale
550
548
if scale_factor > self ._recompute_losses_factor :
551
549
self ._old_scale = self ._scale
552
550
self ._recompute_all_losses ()
0 commit comments