-
Notifications
You must be signed in to change notification settings - Fork 147
Open
Description
Hello, I was testing the baseline model checkpoints and I'm not getting close to the reported numbers for some models. For example, viscoelastic_instability, FNO and TFNO have similar metrics compared to the reported value, but both UNets are very different. Is there anything wrong in the way I'm evaluating these models? Attached a minimal code snippet below
model: polymathic-ai/FNO-viscoelastic_instability
Mean Normalized VRMSE: 0.712697 (Std Dev: 0.033637)
Mean Denormalized VRMSE: 0.706040 (Std Dev: 0.033151)
Reported Number: 0.7212
model: polymathic-ai/TFNO-viscoelastic_instability
Mean Normalized VRMSE: 0.708173 (Std Dev: 0.030011)
Mean Denormalized VRMSE: 0.702301 (Std Dev: 0.028892)
Reported Number: 0.7102
model: polymathic-ai/UNetConvNext-viscoelastic_instability
Mean Normalized VRMSE: 0.946716 (Std Dev: 0.034722)
Mean Denormalized VRMSE: 0.937711 (Std Dev: 0.029086)
Reported Number: 0.2499
model: polymathic-ai/UNetClassic-viscoelastic_instability
Mean Normalized VRMSE: 3.949996 (Std Dev: 1.488076)
Mean Denormalized VRMSE: 3.870942 (Std Dev: 1.455580)
Reported Number: 0.4185
Here's the code I'm using:
import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
import random
from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset
from the_well.utils.download import well_download
from the_well.benchmark.models import FNO, UNetConvNext, UNetClassic, TFNO
from the_well.data.normalization import ZScoreNormalization
base_path = '/data/the_well/datasets'
device = torch.device("cuda")
dataset = WellDataset(
well_base_path=base_path,
well_dataset_name="viscoelastic_instability",
well_split_name="test",
n_steps_input=4,
n_steps_output=1,
use_normalization=True,
normalization_type=ZScoreNormalization,
)
model_path = 'polymathic-ai/UNetClassic-viscoelastic_instability'
model = UNetClassic.from_pretrained(model_path).to(device)
model.eval()
n_samples_per_batch = 8
n_batches = 100
total_samples = n_samples_per_batch * n_batches
all_batch_mean_normalized_errors = []
all_batch_mean_denormalized_errors = []
for batch_idx in tqdm(range(n_batches), desc="Evaluating Batches"):
input_batches = []
output_batches = []
for _ in range(n_samples_per_batch):
random_idx = random.randrange(len(dataset))
sample = dataset[random_idx]
input_batches.append(sample['input_fields'].to(device))
output_batches.append(sample['output_fields'].to(device))
input_batch = torch.stack(input_batches) # Shape: (B, Ti, Lx, Ly, F)
output_batch = torch.stack(output_batches) # Shape: (B, To, Lx, Ly, F)
input_batch = rearrange(input_batch, "B Ti Lx Ly F -> B (Ti F) Lx Ly")
with torch.no_grad():
pred_batch = model(input_batch) # Shape: (B, (To*F), Lx, Ly)
pred_batch = rearrange(pred_batch, "B (Tp F) Lx Ly -> B Tp Lx Ly F", Tp=dataset.n_steps_output) # Shape: (B, Tp, Lx, Ly, F)
normalized_errors_per_sample = VRMSE.eval(pred_batch, output_batch, dataset.metadata)
mean_normalized_err = normalized_errors_per_sample.mean().item()
all_batch_mean_normalized_errors.append(mean_normalized_err)
denormalized_pred_batch = dataset.norm.denormalize_flattened(pred_batch, mode="variable")
denormalized_output_batch = dataset.norm.denormalize_flattened(output_batch, mode="variable")
denormalized_errors_per_sample = VRMSE.eval(denormalized_pred_batch, denormalized_output_batch, dataset.metadata)
mean_denormalized_err = denormalized_errors_per_sample.mean().item()
all_batch_mean_denormalized_errors.append(mean_denormalized_err)
final_mean_normalized = np.mean(all_batch_mean_normalized_errors)
final_std_normalized = np.std(all_batch_mean_normalized_errors)
final_mean_denormalized = np.mean(all_batch_mean_denormalized_errors)
final_std_denormalized = np.std(all_batch_mean_denormalized_errors)
print("\n--- Overall Evaluation Results ---")
print(f"Total samples evaluated: {total_samples} ({n_batches} batches of {n_samples_per_batch} samples)")
print(f"Mean Normalized VRMSE: {final_mean_normalized:.6f} (Std Dev: {final_std_normalized:.6f})")
print(f"Mean Denormalized VRMSE: {final_mean_denormalized:.6f} (Std Dev: {final_std_denormalized:.6f})")Metadata
Metadata
Assignees
Labels
No labels