forked from GeophyAI/seistorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcal_adj.py
More file actions
68 lines (59 loc) · 2.22 KB
/
cal_adj.py
File metadata and controls
68 lines (59 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys, os, h5py
sys.path.append('../../../../')
from seistorch.loss import L2, LocalCoherence
# All figures produced by this script are written under ./figures; create the
# directory up front, tolerating the case where it already exists.
os.makedirs(name='figures', exist_ok=True)
def to_tensor(d):
    """Wrap a NumPy array as a float32 torch tensor.

    ``torch.from_numpy`` shares memory with *d*; the dtype cast makes a copy
    only when *d* is not already float32.
    """
    as_torch = torch.from_numpy(d)
    return as_torch.to(dtype=torch.float32)
def to_np(d):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host memory
    first, so tensors that require grad or live on a GPU are both accepted
    (``Tensor.numpy`` alone rejects device tensors). For a CPU tensor
    ``.cpu()`` is a no-op and the returned array shares memory with *d*.
    """
    return d.detach().cpu().numpy()
def freq_spectrum(d, dt, end_Freq=25):
    """Return the one-sided amplitude spectrum of *d* below ``end_Freq``.

    The FFT is taken along axis 0. For inputs with two or more dimensions the
    absolute spectra are summed over axis 1; any further axes are kept. A 1-D
    input is treated as a single trace and is not summed.

    Parameters
    ----------
    d : array_like
        Data with the sampled (time) axis first.
    dt : float
        Sampling interval along axis 0; frequencies are in its reciprocal unit.
    end_Freq : float, optional
        Exclusive upper frequency bound (default 25).

    Returns
    -------
    freqs : ndarray
        Non-negative frequencies with ``0 <= f < end_Freq``, ascending.
    amp : ndarray
        Amplitude spectrum at ``freqs``; its first axis matches ``freqs``.
    """
    d = np.asarray(d)
    freqs = np.fft.fftfreq(d.shape[0], dt)
    amp = np.abs(np.fft.fft(d, axis=0))
    if amp.ndim > 1:
        amp = np.sum(amp, axis=1)
    # fftfreq lists the non-negative bins first (ascending), then the negative
    # ones; for even lengths the Nyquist bin is reported as negative. Selecting
    # by sign keeps every non-negative bin for both even and odd lengths,
    # whereas slicing the first len//2 entries dropped the top bin for odd ones.
    keep = (freqs >= 0) & (freqs < end_Freq)
    return freqs[keep], amp[keep]
def read_hdf5(path, shot_no):
    """Load one shot gather from an HDF5 file.

    The gather is read in full from the dataset named ``shot_<shot_no>`` and
    returned as an in-memory array; the file is closed before returning.
    """
    dataset_name = f'shot_{shot_no}'
    with h5py.File(path, mode='r') as h5_file:
        return h5_file[dataset_name][:]
SHOT_NO = 3  # shot index read from both HDF5 files


def _as_image(a):
    """Return a 2-D array that ``imshow`` renders with a colormap.

    A 3-D array is treated as (samples, traces, channels), the layout unpacked
    from the observed data below, and only channel 0 is kept. Otherwise a
    trailing axis of length 3 or 4 would be drawn as RGB(A). Other inputs are
    returned unchanged.
    """
    return a[..., 0] if a.ndim == 3 else a


def _plot_pair(left, right, titles, path, shared_limits=True):
    """Draw *left* and *right* side by side, save the figure to *path*, show it.

    With ``shared_limits`` both panels are clipped to the 1st/99th percentiles
    of the plotted *left* panel. Otherwise each panel uses its own percentiles.
    """
    panels = (_as_image(left), _as_image(right))
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    shared = np.percentile(panels[0], [1, 99])
    for ax, panel, title in zip(axes, panels, titles):
        vmin, vmax = shared if shared_limits else np.percentile(panel, [1, 99])
        ax.imshow(panel, vmin=vmin, vmax=vmax, cmap='gray_r', aspect='auto')
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()


observed = read_hdf5('observed.hdf5', SHOT_NO)
synthetic = read_hdf5('initial.hdf5', SHOT_NO)
# the unpacking also enforces that the gathers are 3-D
nsamples, ntraces, nchannels = observed.shape

# show the observed and synthetic data on a common colour scale so their
# amplitudes can be compared directly
_plot_pair(observed, synthetic, ('Observed data', 'Initial data'),
           'figures/Profiles.png')

# add a leading batch axis -> (1, nsamples, ntraces, nchannels)
observed = to_tensor(observed).unsqueeze(0)
synthetic = to_tensor(synthetic).unsqueeze(0)
synthetic.requires_grad_(True)  # adjoint source = d(loss)/d(synthetic)

l2_criterion = L2()
l2_loss = l2_criterion(observed, synthetic)
lc_criterion = LocalCoherence(wt=51, wx=31, sigma_hx=21.0, sigma_tau=11.0)
lc_loss = lc_criterion(observed, synthetic)

# Only first-order gradients are used, so no higher-order graph is built
# (create_graph=True only cost memory). The two losses come from separate
# criterion calls on the leaf `synthetic`. Assuming neither criterion caches
# tensors between calls, releasing the L2 graph leaves the local-coherence
# graph intact.
adj_l2 = to_np(torch.autograd.grad(l2_loss, synthetic)[0]).squeeze()
adj_lc = to_np(torch.autograd.grad(lc_loss, synthetic)[0]).squeeze()

# show the adjoint fields. The two misfits are different functionals and their
# amplitudes need not be comparable. Reusing the L2 limits for both panels can
# saturate or wash out the local-coherence panel, so each is scaled on its own.
_plot_pair(adj_l2, adj_lc, ('Adj by L2', 'Adj by Local coherence'),
           'figures/Adjoint_sources.png', shared_limits=False)