Skip to content

Commit af62ec0

Browse files
committed
fix(exp): batch_size warning
1 parent 431f976 commit af62ec0

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

exp/exp_main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,11 @@ def vali(
161161
for i, batch in enumerate(vali_loader):
162162
# warn if the size does not match
163163
if batch[next(iter(batch))].shape[0] != self.configs.batch_size and current_epoch == 0:
164-
logger.warning(f"Batch No.{i} of total {len(vali_loader)} has actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
164+
logger.warning(f"Exp_Main.vali(): Batch No.{i} out of [0~{len(vali_loader) - 1}] has an actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
165165
if "y_mask" in batch.keys():
166166
if torch.sum(batch["y_mask"]).item() == 0:
167167
if current_epoch == 0:
168-
logger.warning(f"Batch No.{i} of total {len(vali_loader)} has no evaluation point (inferred from y_mask), thus skipping")
168+
logger.warning(f"Exp_Main.vali(): Batch No.{i} out of [0~{len(vali_loader) - 1}] has no evaluation point (inferred from y_mask), thus skipping")
169169
continue
170170
if not self.configs.use_multi_gpu:
171171
batch = {k: v.to(f"cuda:{self.configs.gpu_id}") for k, v in batch.items()}
@@ -237,11 +237,11 @@ def train(self) -> None:
237237
for i, batch in enumerate(train_loader):
238238
# warn if the size does not match
239239
if batch[next(iter(batch))].shape[0] != self.configs.batch_size and epoch == 0:
240-
logger.warning(f"Batch No.{i} of total {len(train_loader)} has actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
240+
logger.warning(f"Exp_Main.train(): Batch No.{i} out of [0~{len(train_loader) - 1}] has an actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
241241
if "y_mask" in batch.keys():
242242
if torch.sum(batch["y_mask"]).item() == 0:
243243
if epoch == 0:
244-
logger.warning(f"Batch No.{i} of total {len(train_loader)} has no evaluation point (inferred from y_mask), thus skipping")
244+
logger.warning(f"Exp_Main.train(): Batch No.{i} out of [0~{len(train_loader) - 1}] has no evaluation point (inferred from y_mask), thus skipping")
245245
continue
246246
model_optim.zero_grad()
247247
if not self.configs.use_multi_gpu:
@@ -534,7 +534,7 @@ def test(self) -> None:
534534
continue
535535
# warn if the size does not match
536536
if batch[next(iter(batch))].shape[0] != self.configs.batch_size:
537-
logger.warning(f"Batch No.{i} of total {len(test_loader)} has actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
537+
logger.warning(f"Exp_Main.test(): Batch No.{i} out of [0~{len(test_loader) - 1}] has an actual batch_size={batch[next(iter(batch))].shape[0]}, which is not the same as --batch_size={self.configs.batch_size}")
538538
# continue
539539
if not self.configs.use_multi_gpu:
540540
batch = {k: v.to(f"cuda:{self.configs.gpu_id}") for k, v in batch.items()}

0 commit comments

Comments (0)