Preprint Available At: https://arxiv.org/abs/2405.20516
WaveCastNet is an AI-driven framework designed to forecast high-resolution wavefields, using convolutional Long Expressive Memory (ConvLEM) within a sequence-to-sequence architecture. By capturing long-term dependencies and multi-scale patterns across space and time, it delivers rapid, accurate predictions without relying on traditional earthquake magnitude or epicenter estimations.
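WaveCastNet's ConvLEM replaces the dense matrix products of Long Expressive Memory (LEM) with convolutions over the spatial wavefield. As a point of reference, the plain (non-convolutional) LEM recurrence with its two learned time scales can be sketched as follows; the weight names (`W1`, `V1`, ...) are illustrative, not the repository's API:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lem_step(x, y, z, params, dt=1.0):
    """One step of the LEM recurrence (Rusch et al.).

    ConvLEM swaps these matrix products for convolutions; this dense
    version only illustrates the multiscale gating structure.
    """
    W1, V1, W2, V2, Wz, Vz, Wy, Vy = params
    # Two learned, state-dependent time scales in (0, dt)
    dt1 = dt * sigmoid(W1 @ x + V1 @ y)
    dt2 = dt * sigmoid(W2 @ x + V2 @ y)
    # Multiscale convex-combination updates of the two hidden states
    z_new = (1 - dt1) * z + dt1 * np.tanh(Wz @ x + Vz @ y)
    y_new = (1 - dt2) * y + dt2 * np.tanh(Wy @ x + Vy @ z_new)
    return y_new, z_new

rng = np.random.default_rng(0)
d_in, d_h = 3, 8
# Even-indexed weights act on the input, odd-indexed on a hidden state
params = [rng.normal(scale=0.1, size=(d_h, d_in if i % 2 == 0 else d_h))
          for i in range(8)]

y = np.zeros(d_h)
z = np.zeros(d_h)
for _ in range(5):
    y, z = lem_step(rng.normal(size=d_in), y, z, params)
```

Because each update is a convex combination of the previous state and a `tanh` term, the hidden states stay bounded, which is one reason LEM handles long sequences stably.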
Figure 1: Illustration of the problem setup: (a) Synthetic data simulation area in the San Francisco Bay Area. (b) Snapshot of the viscoelastic velocity wavefield generated by a point-source earthquake.
- Point-source prediction with uncertainty estimation

- Generate 100-second prediction from 5.6-second input
- Robust evaluation under various noise tests
- Empirical noise
- Gaussian noise
- Arrival latency noise
- Generalization to finite-fault settings
- Comparative study against transformer-based models
- Zero-shot generalization from synthetic data to real-world examples
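The Gaussian and arrival-latency perturbations used in the robustness tests can be sketched as simple transforms of an input waveform array. The function names and parameters below are illustrative, not the repository's API:

```python
import numpy as np

rng = np.random.default_rng(42)

def add_gaussian_noise(wave, noise_scale=0.1):
    """Additive Gaussian noise scaled to a fraction of the signal's std."""
    return wave + rng.normal(scale=noise_scale * wave.std(), size=wave.shape)

def add_arrival_latency(wave, max_shift=5):
    """Randomly delay the trace by up to `max_shift` samples,
    zero-padding the front, mimicking a late-arriving signal."""
    shift = rng.integers(0, max_shift + 1)
    delayed = np.zeros_like(wave)
    delayed[shift:] = wave[: wave.shape[0] - shift]
    return delayed

wave = np.sin(np.linspace(0, 4 * np.pi, 200))   # toy one-component trace
noisy = add_gaussian_noise(wave)
late = add_arrival_latency(wave, max_shift=10)
```

Empirical noise would instead be drawn from recorded pre-event seismic background rather than a Gaussian.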
The wavefield data and pretrained checkpoints can be accessed via Google Drive:
- all_input.npy: Training data from 300 events (80%).
- all_validation.npy, all_test.npy: Validation data from 45 events and test data from 30 events (20%).
- mean.npy: Per-pixel mean of the training data for normalization.
- std.npy: Per-pixel standard deviation of the training data for normalization.
- best_lem_dense_.pt: Pretrained checkpoint for dense input.
- best_lem_irr_mask_shakealert_.pt: Pretrained checkpoint for models with sparse, irregular input and random encoder masking enabled.
All data has been preprocessed, normalized per pixel, and cropped to remove boundary artifacts. Each sequence has a shape of
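The per-pixel normalization corresponding to `mean.npy` and `std.npy` can be sketched as below. The array dimensions here are stand-ins (the frame size matches the model's `(43, 28)` grid, but the event and time axes are illustrative):

```python
import numpy as np

# Stand-in for the training array: (events, time, channels, H, W);
# in practice this would be loaded from all_input.npy.
train = np.random.default_rng(0).normal(size=(4, 10, 3, 43, 28))

# Per-pixel statistics, aggregated over the event and time axes
mean = train.mean(axis=(0, 1))            # shape (3, 43, 28)
std = train.std(axis=(0, 1))

# Broadcasts over the leading (event, time) axes
normalized = (train - mean) / (std + 1e-8)
```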
To load the pretrained models, first navigate to the model directory and download the required checkpoints from Google Drive:
cd src/models_earthauqkes
Then, load the corresponding models for dense and sparse input sampling scenarios:
import numpy as np
import torch

from AEConvLEM_sparse import AEConvLEM_sparse
from AEConvLEM_dense import AEConvLEM_dense

# Masking ratio: fraction of grid points without a real ShakeAlert station
all_sample = len(np.load('filtered_coord.npy'))
real_sample = len(np.load('shakealert_coords.npy'))
mask_ratio = 1. - real_sample / all_sample

sparse_model = AEConvLEM_sparse(dt=1, num_channels=3, num_kernels=144,
                                kernel_size=(3, 3), padding=(1, 1), activation="tanh",
                                frame_size=(43, 28), mask_mode=1, mask_ratio=mask_ratio)
dense_model = AEConvLEM_dense(dt=1, num_channels=3, num_kernels=144,
                              kernel_size=(3, 3), padding=(1, 1), activation="tanh",
                              frame_size=(43, 28))

# Load checkpoints
state_dict_sparse = torch.load("checkpoints/best_lem_irr_mask_shakealert_.pt", map_location=torch.device('cuda'))
state_dict_dense = torch.load("checkpoints/best_lem_dense_.pt", map_location=torch.device('cuda'))

# Strip the 'module.' prefix added by DataParallel at save time
state_dict_sparse = {k.replace("module.", ""): v for k, v in state_dict_sparse.items()}
state_dict_dense = {k.replace("module.", ""): v for k, v in state_dict_dense.items()}

# Load the weights into the models
sparse_model.load_state_dict(state_dict_sparse)
dense_model.load_state_dict(state_dict_dense)
- Dense Sampling Uncertainty Estimation: Ensemble prediction using 50 trained models
- Sparse Sampling Uncertainty Estimation: Ensemble prediction with 50 randomly sampled station sets based on a trained Masked Autoencoder
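In both settings the uncertainty estimate reduces to per-pixel statistics over a stack of 50 ensemble predictions. A minimal sketch, with illustrative array sizes standing in for the real forecast sequences:

```python
import numpy as np

rng = np.random.default_rng(1)
# Stack of 50 ensemble forecasts (members, time, channels, H, W);
# the time length here is illustrative, not the real horizon.
preds = rng.normal(loc=2.0, scale=0.3, size=(50, 20, 3, 43, 28))

ens_mean = preds.mean(axis=0)   # point estimate of the wavefield
ens_std = preds.std(axis=0)     # per-pixel predictive uncertainty
```

For dense sampling the 50 members come from independently trained models; for sparse sampling they come from 50 random station subsets fed to one trained masked autoencoder.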
Figure 4: Peak Ground Velocity (PGV) and
python WaveCastNet/earthquake_train.py --model LEM_dense --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 #Dense Sampling
python WaveCastNet/earthquake_train.py --model LEM_sparse --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 #Sparse Sampling
python WaveCastNet/earthquake_train.py --model LEM_dense --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4 --training_uq 1 --load_seed 2 # Train one ensemble member for uncertainty estimation
python earthquake_train.py --model LEM --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4
python earthquake_train.py --model LSTM --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4
python earthquake_train.py --model GRU --num_kernels 144 --activation tanh --batch_size 64 --learning_rate 5e-4
python earthquake_train.py --model Swin --num_kernels 144 --patch_size 3 4 4 --batch_size 64 --learning_rate 5e-4
python earthquake_train.py --model Time-s-pyramid --num_kernels 192 --patch_size 1 8 8 --batch_size 64 --learning_rate 5e-4 # Time-S-Former
python earthquake_train.py --model Time-s-plain --num_kernels 192 --patch_size 1 8 8 --batch_size 64 --learning_rate 5e-4 # Video Swin Transformer
- Example Event: Berkeley 2018 M4.4 event at a depth of 12.3 km, recorded by 178 stations.
- Model: Sparse sampling model trained exclusively on synthetic point-source simulations
- One-time prediction: Generates a full 110-second high-resolution sequence in one step.
- Rolling prediction: Predicts the next 15.6-second segment using the current input and repeats the process for six steps.
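The rolling mode feeds each predicted segment back in as the next input. A minimal sketch with a stand-in for the trained model (the segment length in samples is illustrative):

```python
import numpy as np

def model_step(segment):
    """Stand-in for the trained sparse-sampling model: maps the current
    input window to the next predicted segment (identity here)."""
    return segment

seg_len = 156                                 # toy segment length in samples
window = np.zeros((seg_len, 3, 43, 28))       # toy input window
rollout = []
for _ in range(6):                            # six rolling steps
    window = model_step(window)               # predict the next segment
    rollout.append(window)                    # and feed it back as input
full_sequence = np.concatenate(rollout, axis=0)
```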
Figure 5: Examples of real-world time series (Left): San Jose (NP.1788) and Woodside (BK.JRSC). High-resolution PGV and
export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model LEM -lr 5e-4 --num_layers 3 --width 64
export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model LSTM -lr 5e-4 --num_layers 3 --width 64
export CUDA_VISIBLE_DEVICES=4; python MovingMnist_train.py --model QRNN -lr 5e-4 --num_layers 3 --width 64
python RBC_train.py --model LEM -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50
python RBC_train.py --model LSTM -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50
python RBC_train.py --model QRNN -lr 5e-4 --width 72 --activation tanh --input_steps 50 --future_steps 50