
WaveCastNet: An AI-enabled Wavefield Forecasting Framework for Earthquake Early Warning

Preprint Available At: https://arxiv.org/abs/2405.20516

Overview

WaveCastNet is an AI-driven framework designed to forecast high-resolution wavefields, using convolutional Long Expressive Memory (ConvLEM) within a sequence-to-sequence architecture. By capturing long-term dependencies and multi-scale patterns across space and time, it delivers rapid, accurate predictions without relying on traditional earthquake magnitude or epicenter estimations.

Figure 1: Illustration of the problem setup: (a) Synthetic data simulation area in the San Francisco Bay Area. (b) Snapshot of the viscoelastic velocity wavefield generated by a point-source earthquake.

Highlights

  • Point-source prediction with uncertainty estimation
    • Generates a 100-second prediction from a 5.6-second input
  • Robust evaluation under various noise tests
    • Empirical noise
    • Gaussian noise
    • Arrival latency noise
  • Generalization to finite-fault settings
  • Comparative study against transformer-based models
  • Zero-shot generalization from synthetic data to real-world examples

Data Availability

The wavefield data and pretrained checkpoints can be accessed via Google Drive:

  • all_input.npy: Training data from 300 events (80%).
  • {all_validation.npy, all_test.npy}: Validation data from 45 events and test data from 30 events (20%).
  • mean.npy: Per-pixel mean of the training data for normalization.
  • std.npy: Per-pixel standard deviation of the training data for normalization.
  • best_lem_dense_.pt: Pretrained checkpoint for dense input.
  • best_lem_irr_mask_shakealert_.pt: Pretrained checkpoint for models with sparse, irregular input and random encoder masking enabled.

All data has been preprocessed, normalized per pixel, and cropped to remove boundary artifacts. Each sequence has a shape of $3 \times 344 \times 224 \times 461$ ($\text{Channel} \times H \times W \times \text{Seq len}$).
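Since the shipped arrays are already in normalized space, model outputs can be mapped back to physical velocity units with the provided per-pixel statistics. A minimal sketch, using small stand-in arrays in place of `mean.npy` / `std.npy` (the real sequences are $3 \times 344 \times 224 \times 461$):

```python
import numpy as np

# Sketch: undo the per-pixel normalization on a model output.
# Stand-in shapes; the real data is (3, 344, 224, 461).
C, H, W, T = 3, 8, 8, 16
mean = np.zeros((C, H, W, 1))          # stand-in for np.load("mean.npy")
std = np.full((C, H, W, 1), 2.0)       # stand-in for np.load("std.npy")

pred_norm = np.ones((C, H, W, T))      # a prediction in normalized space
pred_phys = pred_norm * std + mean     # back to physical velocity units
```

The trailing length-1 time axis on `mean` and `std` lets NumPy broadcasting apply the same per-pixel statistics to every time step.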

To load the pretrained models, first navigate to the model directory and download the required checkpoints from Google Drive:

cd src/models_earthauqkes

Then, load the corresponding models for dense and sparse input sampling scenarios:

import numpy as np
import torch

from AEConvLEM_sparse import AEConvLEM_sparse
from AEConvLEM_dense import AEConvLEM_dense

# Mask ratio: fraction of grid points without a real (ShakeAlert) station
all_sample = len(np.load('filtered_coord.npy'))
real_sample = len(np.load('shakealert_coords.npy'))
mask_ratio = 1. - real_sample / all_sample

sparse_model = AEConvLEM_sparse(dt=1, num_channels=3, num_kernels=144, 
            kernel_size=(3, 3), padding=(1, 1), activation="tanh", 
            frame_size=(43, 28), mask_mode=1, mask_ratio=mask_ratio)

dense_model = AEConvLEM_dense(dt=1, num_channels=3, num_kernels=144, 
            kernel_size=(3, 3), padding=(1, 1), activation="tanh", 
            frame_size=(43, 28))

# Load checkpoints
state_dict_sparse = torch.load("checkpoints/best_lem_irr_mask_shakealert_.pt", map_location=torch.device('cuda'))
state_dict_dense = torch.load("checkpoints/best_lem_dense_.pt", map_location=torch.device('cuda'))

# Remove the 'module.' prefix saved by DataParallel
state_dict_sparse = {k.replace("module.", ""): v for k, v in state_dict_sparse.items()}
state_dict_dense = {k.replace("module.", ""): v for k, v in state_dict_dense.items()}

# Load the weights into the models
sparse_model.load_state_dict(state_dict_sparse)
dense_model.load_state_dict(state_dict_dense)

Model Structure

Figure 2: Sequence-to-sequence architecture supporting both dense image inputs (i) and sparse vector inputs (ii) for generating high-resolution predictions (Left). Convolutional Long Expressive Memory (ConvLEM) Cell structure (Right).

Table 1: Main performance metrics for the dense and sparse sampling point-source scenarios.

Figure 3: Examples of point-source time series visualizations.
  • Dense Sampling Uncertainty Estimation: Ensemble prediction using 50 trained models
  • Sparse Sampling Uncertainty Estimation: Ensemble prediction with 50 randomly sampled station sets based on a trained Masked Autoencoder
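The ensemble uncertainty estimates above reduce to per-pixel statistics over the member predictions. A minimal sketch, with randomly generated stand-ins for the 50 member forecasts (shapes are illustrative, not the repository's API):

```python
import numpy as np

# Stand-in for 50 ensemble member predictions, shape (members, C, H, W, T)
rng = np.random.default_rng(0)
preds = rng.normal(loc=1.0, scale=0.1, size=(50, 3, 8, 8, 5))

ens_mean = preds.mean(axis=0)   # point forecast
ens_std = preds.std(axis=0)     # per-pixel uncertainty estimate
```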

Figure 4: Peak Ground Velocity (PGV) and $T_{PGV}$ mean predictions ((a), (d)), absolute errors ((b), (e)), and standard deviations ((c), (f)) for dense (left) and sparse (right) sampling scenarios.
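PGV and $T_{PGV}$ maps like those in Figure 4 can be derived from a predicted velocity wavefield by reducing over the time axis. A hedged sketch, with illustrative shapes and an assumed sampling interval (`dt = 0.25` s is a placeholder, not the simulation's actual value):

```python
import numpy as np

# Velocity wavefield with 3 components: shape (C, H, W, T), toy sizes
wave = np.random.default_rng(0).normal(size=(3, 8, 8, 20))

speed = np.linalg.norm(wave, axis=0)   # velocity magnitude per pixel, (H, W, T)
pgv = speed.max(axis=-1)               # peak ground velocity map, (H, W)
t_pgv = speed.argmax(axis=-1) * 0.25   # time of the peak; dt = 0.25 s assumed
```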

Training for Dense and Sparse Sampling Scenarios

python WaveCastNet/earthquake_train.py --model LEM_dense --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 #Dense Sampling

python WaveCastNet/earthquake_train.py --model LEM_sparse --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 #Sparse Sampling

Uncertainty Quantification Training

python WaveCastNet/earthquake_train.py --model LEM_dense --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 --training_uq 1 --load_seed 2

Finite Fault Generalization

Table 2: Systematic evaluation on higher-magnitude finite-fault earthquakes (Left). Performance comparison between seq2seq frameworks using different recurrent cells, and state-of-the-art transformers for forecasting small point-source earthquakes (Right).

Seq2Seq Ablation Studies

python earthquake_train.py --model LEM --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4

python earthquake_train.py --model LSTM --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4

python earthquake_train.py --model GRU --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4

Comparative Studies with Transformer Architectures

python earthquake_train.py --model Swin --num_kernels 144 --patch_size 3 4 4 --batch_size 64 --learning_rate 5e-4

python earthquake_train.py --model Time-s-pyramid --num_kernels 192 --patch_size 1 8 8 --batch_size 64 --learning_rate 5e-4 # Time-S-Former

python earthquake_train.py --model Time-s-plain --num_kernels 192 --patch_size 1 8 8 --batch_size 64 --learning_rate 5e-4 # Video Swin Transformer

Zero-shot Real-world Generalization

  • Example Event: the 2018 M4.4 Berkeley earthquake at a depth of 12.3 km, recorded by 178 stations.
  • Model: Sparse sampling model trained exclusively on synthetic point-source simulations
  • One-time prediction: Generates a full 110-second high-resolution sequence in one step.
  • Rolling prediction: Predicts the next 15.6-second segment using the current input and repeats the process for six steps.
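The rolling scheme above is an autoregressive rollout: each predicted segment becomes the input for the next step. A minimal sketch with a toy stand-in model (`rolling_forecast` and the shapes are illustrative, not the repository's actual interface):

```python
import numpy as np

def rolling_forecast(model, first_input, num_steps=6):
    """Autoregressive rollout: feed each predicted segment back as the next input."""
    segments = []
    current = first_input
    for _ in range(num_steps):
        pred = model(current)      # predict the next segment
        segments.append(pred)
        current = pred             # reuse the prediction as the next input
    return np.concatenate(segments, axis=-1)  # stack along the time axis

# Toy stand-in model: keeps shape (C, H, W, T), increments values per call
toy = lambda x: x + 1.0
out = rolling_forecast(toy, np.zeros((3, 4, 4, 10)), num_steps=6)
```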

Figure 5: Examples of real-world time series (Left): San Jose (NP.1788) and Woodside (BK.JRSC). High-resolution PGV and $T_{PGV}$ predictions (Right).

Experiments in Appendix

Train Moving MNIST

export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model LEM -lr 5e-4 --num_layers 3 --width 64
export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model LSTM -lr 5e-4 --num_layers 3 --width 64
export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model QRNN -lr 5e-4 --num_layers 3 --width 64

Train RBC Fluid Flow

python RBC_train.py --model LEM -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50
python RBC_train.py --model LSTM -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50
python RBC_train.py --model QRNN -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50
