Skip to content

Commit 053a504

Browse files
committed
printing hot fix
fixing the
1 parent c2da628 commit 053a504

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

flood_forecast/pytorch_training.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def torch_single_train(model: PyTorchForecast,
379379
output = output[:, :, 0:multi_targets]
380380
labels = trg[:, -pred_len:, 0:multi_targets]
381381
multi_targets = False
382-
print(trg.shape)
383382
if model.params["dataset_params"]["class"] == "GeneralClassificationLoader":
384383
labels = trg
385384
elif multi_targets == 1:
@@ -552,7 +551,7 @@ def compute_validation(validation_loader: DataLoader,
552551
scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
553552
wandb.log({'epoch': epoch, val_or_test: scaled})
554553
if classification:
555-
print("Plotting classification metrics")
554+
print("Plotting test classification metrics")
556555
label_list = torch.cat(label_list)
557556
label_list = label_list[:, 0, :]
558557
mod_output1 = torch.cat(mod_output_list)[:, 0, :]

0 commit comments

Comments
 (0)