From 639993c65f79a6044693e23712f7e4b316c0f0cd Mon Sep 17 00:00:00 2001 From: ms802x <39892096+ms802x@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:27:05 +0300 Subject: [PATCH] Update trainer.py Fixed scale_schedule. It was a regular list, converted to numpy array before .prod() --- trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer.py b/trainer.py index e0daf79..c48189f 100644 --- a/trainer.py +++ b/trainer.py @@ -204,7 +204,7 @@ def train_step( if self.reweight_loss_by_scale: lw = [] - last_scale_area = np.sqrt(scale_schedule[-1].prod()) + last_scale_area = np.sqrt(scale_schedule[-1]).prod() for (pt, ph, pw) in scale_schedule[:training_scales]: this_scale_area = np.sqrt(pt * ph * pw) lw.extend([last_scale_area / this_scale_area for _ in range(pt * ph * pw)])