forked from GeophyAI/seistorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path fwi_classic.py
More file actions
64 lines (61 loc) · 2.23 KB
/
fwi_classic.py
File metadata and controls
64 lines (61 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch, tqdm
import numpy as np
import matplotlib.pyplot as plt
from utils import *
from configure import *
# Global PyTorch / RNG configuration.
# `torch.cuda` has no `cudnn_enabled` switch (assigning it just creates an unused
# attribute); cuDNN is toggled through torch.backends.cudnn.enabled.
torch.backends.cudnn.enabled = True
# cuDNN autotuner: picks the fastest kernels when input shapes stay fixed.
torch.backends.cudnn.benchmark = True
# Seed both RNGs. The training loop draws its random shot batches with
# np.random, so seeding torch alone does not make runs reproducible.
# `seed` is expected to come from the star imports above.
torch.manual_seed(seed)
np.random.seed(seed)
# Load velocity models: `vel` is the starting model that gets updated, `true`
# is only used to track the model error during inversion.
vel = np.load("../../models/marmousi_model/linear_vp.npy")
true = np.load("../../models/marmousi_model/true_vp.npy")
# Drop `expand` columns from each lateral (x, axis-1) edge, then decimate both
# axes by model_scale. The explicit stop index keeps expand == 0 working:
# x[:, 0:-0] would be an empty array.
true = true[:, expand:true.shape[1] - expand][::model_scale, ::model_scale]
vel = vel[:, expand:vel.shape[1] - expand][::model_scale, ::model_scale]
# Pad all four sides by pmln cells (edge replication); presumably this is the
# PML absorbing region whose coefficients are generated next.
vel = np.pad(vel, ((pmln, pmln), (pmln, pmln)), mode="edge")
pmlc = generate_pml_coefficients_2d(vel.shape, N=pmln, multiple=False)
# The padded velocity tensor is the optimisation variable.
vel = torch.from_numpy(vel).float().to(dev)
vel.requires_grad = True
domain = vel.shape
nz, nx = domain  # padded grid size (depth, lateral)
# Source wavelet: Ricker with f=fm (presumably the dominant frequency),
# sampled at nt steps of dt and shifted later by `delay` samples.
wave = ricker(
    np.arange(nt) * dt - delay * dt,
    f=fm,
)
# Data misfit: mean-squared error between synthetic and observed records.
l2loss = torch.nn.MSELoss(reduction="mean")
# Adam updates the velocity tensor directly with step size lr_vel.
opt = torch.optim.Adam(params=[vel], lr=lr_vel)
# Acquisition geometry: one source every srcx_step cells across the interior
# (non-padded) columns, all placed at depth index srcz.
src_x = np.arange(pmln, nx - pmln, srcx_step)
src_z = srcz * np.ones_like(src_x)
# Each source is a plain [x, z] list.
sources = list(map(list, zip(src_x.tolist(), src_z.tolist())))
# Observed data (first axis indexed by shot), moved to the device as float32.
obs = torch.from_numpy(np.load("obs.npy")).float().to(dev)
# Per-epoch histories: data misfit and model error against `true`.
LOSS, MERROR = [], []
# Keyword arguments shared by every forward() call; src_list is overwritten
# with the sampled shots each epoch.
kwargs = {
    "wave": wave,
    "b": pmlc,
    "src_list": np.array(sources),
    "domain": domain,
    "dt": dt,
    "h": dh,
    "dev": dev,
    "recz": recz,
    "pmln": pmln,
}
# imshow settings; extent covers the unpadded model in physical units (dh per cell).
kwargs_imshow = {
    "vmin": vmin,
    "vmax": vmax,
    "aspect": "auto",
    "cmap": "seismic",
    "extent": [0, (nx - 2 * pmln) * dh, (nz - 2 * pmln) * dh, 0],
}
# Main inversion loop: each epoch models a random batch of shots, computes the
# data misfit and takes one Adam step on the velocity model.
all_sources = np.array(sources)  # built once rather than every epoch
# Slices that strip the pmln padding; explicit stops keep pmln == 0 safe.
interior = (slice(pmln, nz - pmln), slice(pmln, nx - pmln))
for epoch in tqdm.trange(EPOCHS):
    # Select part of the data. Shots are drawn with replacement, so a batch
    # may contain the same shot more than once.
    rand_shots = np.random.randint(0, len(sources), size=batch_size).tolist()
    kwargs.update(src_list=all_sources[rand_shots])
    syn = forward(c=vel, **kwargs)
    # Misfit against the matching observed shots only.
    loss = l2loss(syn, obs[rand_shots])
    opt.zero_grad()
    loss.backward()
    if reset_water:
        # Zero the gradient of the top water_grid rows so they are not updated.
        # NOTE(review): rows are counted from the top of the *padded* array,
        # so the pmln padding rows are included — confirm that is intended.
        vel.grad[:water_grid] = 0.
    opt.step()
    loss_value = loss.item()
    LOSS.append(loss_value)
    # Detach before copying to host so autograd does not record the copy.
    inverted = vel.detach().cpu().numpy()[interior]
    MERROR.append(np.sum((true - inverted) ** 2))
    if epoch % show_every == 0:
        print(f"Epoch: {epoch}, Loss: {loss_value}")
        plt.imshow(inverted, **kwargs_imshow)
        plt.show()
# Plot the data-misfit history.
plt.plot(LOSS)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
# Model-error history: one sum of squared differences to `true` per epoch.
np.save("model_error_by_classic.npy", np.array(MERROR))
# Final model with the pmln padding removed. Explicit stop indices keep
# pmln == 0 working ([0:-0] would be empty); detach before the host copy.
final_model = vel.detach().cpu().numpy()[pmln:nz - pmln, pmln:nx - pmln]
np.save("inverted_by_classic.npy", final_model)