Skip to content

Inconsistent Pre-trained Baseline Performance #49

@ArshKA

Description

@ArshKA

Hello, I was testing the baseline model checkpoints, and for some models I'm not getting close to the reported numbers. For example, on viscoelastic_instability, FNO and TFNO give metrics similar to the reported values, but both UNets are very different. Is there anything wrong with the way I'm evaluating these models? I've attached a minimal code snippet below.

model: polymathic-ai/FNO-viscoelastic_instability
Mean Normalized VRMSE: 0.712697 (Std Dev: 0.033637)
Mean Denormalized VRMSE: 0.706040 (Std Dev: 0.033151)
Reported Number: 0.7212

model: polymathic-ai/TFNO-viscoelastic_instability
Mean Normalized VRMSE: 0.708173 (Std Dev: 0.030011)
Mean Denormalized VRMSE: 0.702301 (Std Dev: 0.028892)
Reported Number:  0.7102

model: polymathic-ai/UNetConvNext-viscoelastic_instability
Mean Normalized VRMSE: 0.946716 (Std Dev: 0.034722)
Mean Denormalized VRMSE: 0.937711 (Std Dev: 0.029086)
Reported Number:  0.2499

model: polymathic-ai/UNetClassic-viscoelastic_instability
Mean Normalized VRMSE: 3.949996 (Std Dev: 1.488076)
Mean Denormalized VRMSE: 3.870942 (Std Dev: 1.455580)
Reported Number: 0.4185

Here's the code I'm using:

import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
import random

from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset
from the_well.utils.download import well_download
from the_well.benchmark.models import FNO, UNetConvNext, UNetClassic, TFNO
from the_well.data.normalization import ZScoreNormalization

# --- Evaluation setup -------------------------------------------------------
# Root directory that holds the downloaded datasets.
base_path = '/data/the_well/datasets'

# Use the GPU when one is present, otherwise fall back to the CPU so the script
# still runs (slowly) on machines without CUDA instead of failing on `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Samples are drawn with the `random` module further down. Seeding it makes the
# evaluated subset identical from run to run, so numbers obtained for different
# checkpoints are computed on the same samples and can be compared directly.
sampling_seed = 0
random.seed(sampling_seed)

# Test split, 4 input steps -> 1 output step, z-score normalized fields.
dataset = WellDataset(
    well_base_path=base_path,
    well_dataset_name="viscoelastic_instability",
    well_split_name="test",
    n_steps_input=4,
    n_steps_output=1,
    use_normalization=True,
    normalization_type=ZScoreNormalization,
)

# Checkpoint under evaluation. NOTE: the model class used below must match the
# architecture named in `model_path`.
model_path = 'polymathic-ai/UNetClassic-viscoelastic_instability'
model = UNetClassic.from_pretrained(model_path).to(device)
model.eval()  # inference mode for layers such as dropout / batch norm

# Evaluation budget: `n_batches` batches of `n_samples_per_batch` random samples.
n_samples_per_batch = 8
n_batches = 100
total_samples = n_samples_per_batch * n_batches

# One entry per batch: the batch-mean VRMSE in normalized / denormalized space.
all_batch_mean_normalized_errors = []
all_batch_mean_denormalized_errors = []


# Score the model on `n_batches` mini-batches of samples drawn at random from the
# dataset, recording the batch-mean VRMSE both in normalized and denormalized space.
for _batch in tqdm(range(n_batches), desc="Evaluating Batches"):
    # Draw the samples for this batch; indices are picked independently, so the
    # same sample may appear more than once.
    picked = [dataset[random.randrange(len(dataset))] for _ in range(n_samples_per_batch)]

    # Move each sample's fields to the target device and stack along a new batch axis.
    stacked_inputs = torch.stack(
        [item['input_fields'].to(device) for item in picked]
    )  # Shape: (B, Ti, Lx, Ly, F)
    targets = torch.stack(
        [item['output_fields'].to(device) for item in picked]
    )  # Shape: (B, To, Lx, Ly, F)

    # Fold the time axis into the channel axis and put channels first for the model.
    model_inputs = rearrange(stacked_inputs, "B Ti Lx Ly F -> B (Ti F) Lx Ly")

    # Forward pass only; no gradients are needed for evaluation.
    with torch.no_grad():
        raw_predictions = model(model_inputs)  # Shape: (B, (To*F), Lx, Ly)

    # Unfold the channel axis back into (time, field) and put fields last again.
    predictions = rearrange(
        raw_predictions,
        "B (Tp F) Lx Ly -> B Tp Lx Ly F",
        Tp=dataset.n_steps_output,
    )  # Shape: (B, Tp, Lx, Ly, F)

    # VRMSE on the normalized fields, averaged over everything in the batch.
    all_batch_mean_normalized_errors.append(
        VRMSE.eval(predictions, targets, dataset.metadata).mean().item()
    )

    # Same metric after mapping predictions and targets back to physical units.
    physical_predictions = dataset.norm.denormalize_flattened(predictions, mode="variable")
    physical_targets = dataset.norm.denormalize_flattened(targets, mode="variable")
    all_batch_mean_denormalized_errors.append(
        VRMSE.eval(physical_predictions, physical_targets, dataset.metadata).mean().item()
    )


# Aggregate the per-batch means: their average is the headline number, and their
# standard deviation is the spread *across batches* (not across individual samples).
final_mean_normalized, final_std_normalized = (
    np.mean(all_batch_mean_normalized_errors),
    np.std(all_batch_mean_normalized_errors),
)
final_mean_denormalized, final_std_denormalized = (
    np.mean(all_batch_mean_denormalized_errors),
    np.std(all_batch_mean_denormalized_errors),
)

# Emit the report one line at a time.
report_lines = (
    "\n--- Overall Evaluation Results ---",
    f"Total samples evaluated: {total_samples} ({n_batches} batches of {n_samples_per_batch} samples)",
    f"Mean Normalized VRMSE: {final_mean_normalized:.6f} (Std Dev: {final_std_normalized:.6f})",
    f"Mean Denormalized VRMSE: {final_mean_denormalized:.6f} (Std Dev: {final_std_denormalized:.6f})",
)
for report_line in report_lines:
    print(report_line)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions