diff --git a/trainer.py b/trainer.py index e0daf79..c48189f 100644 --- a/trainer.py +++ b/trainer.py @@ -204,7 +204,7 @@ def train_step( if self.reweight_loss_by_scale: lw = [] - last_scale_area = np.sqrt(scale_schedule[-1].prod()) + last_scale_area = np.sqrt(scale_schedule[-1]).prod() for (pt, ph, pw) in scale_schedule[:training_scales]: this_scale_area = np.sqrt(pt * ph * pw) lw.extend([last_scale_area / this_scale_area for _ in range(pt * ph * pw)])