forked from GeophyAI/seistorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcal_adj.py
More file actions
144 lines (128 loc) · 5.23 KB
/
cal_adj.py
File metadata and controls
144 lines (128 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.signal import hilbert

sys.path.append('../../../../')
from seistorch.loss import L2, Envelope
from seistorch.transform import envelope
def to_tensor(d):
    """Wrap the NumPy array ``d`` as a torch tensor cast to float32.

    ``torch.from_numpy`` accepts only ``numpy.ndarray`` inputs.
    """
    wrapped = torch.from_numpy(d)
    return wrapped.float()
def to_np(d):
    """Return the torch tensor ``d`` as a NumPy array, detached from autograd.

    The tensor is moved to host memory first. ``Tensor.numpy()`` only works
    on CPU tensors, so without this step GPU results would raise.
    """
    # Detach before copying so the device transfer is not recorded in the graph.
    return d.detach().cpu().numpy()
def freq_spectrum(d, dt, end_Freq=25):
    """Return the trace-stacked, one-sided amplitude spectrum of ``d``.

    Parameters
    ----------
    d : array_like
        Data with time along axis 0. A 1-D array is treated as a single
        trace. For N-D input the FFT amplitudes are summed over axis 1 only,
        so any further axes are kept in ``amp``.
    dt : float
        Sampling interval along axis 0.
    end_Freq : float
        Exclusive upper limit on the returned frequencies.

    Returns
    -------
    freqs : np.ndarray
        Non-negative FFT frequencies strictly below ``end_Freq``.
    amp : np.ndarray
        ``|FFT(d)|`` summed over axis 1, restricted to ``freqs`` along axis 0.
    """
    d = np.asarray(d)
    if d.ndim == 1:
        # A single trace becomes a one-column gather so the axis-1 stack works.
        d = d[:, np.newaxis]
    freqs = np.fft.fftfreq(d.shape[0], dt)
    amp = np.abs(np.fft.fft(d, axis=0)).sum(axis=1)
    # fftfreq lists the non-negative frequencies first. Selecting them with a
    # mask (rather than slicing [:n//2]) keeps the highest positive bin when
    # the number of samples is odd, and is identical for even n.
    keep = (freqs >= 0) & (freqs < end_Freq)
    return freqs[keep], amp[keep]
def read_hdf5(file, shot_no):
    """Load dataset ``shot_<shot_no>`` from the HDF5 file ``file`` into memory.

    h5py is imported lazily, so it is only required when this function runs.
    The file is closed before the in-memory copy is returned.
    """
    import h5py
    key = f'shot_{shot_no}'
    with h5py.File(file, 'r') as handle:
        return handle[key][:]
# Load shot 5 from the observed dataset and from the dataset modelled in the
# initial model. The observed file is read first.
obs, syn = (read_hdf5(path, 5) for path in ('observed.hdf5', 'initial.hdf5'))
# The synthetic gather is expected to be 3-D: (samples, traces, channels).
nsamples, ntraces, nchannels = syn.shape
print(f"nsamples: {nsamples}, ntraces: {ntraces}")
# show the observed and synthetic data
observed = obs
synthetic = syn
# plt.savefig does not create missing directories, so make sure the output
# folder exists before the first figure is written.
os.makedirs('figures', exist_ok=True)
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
# Both panels use the observed data's 1st-99th percentile clip range so
# their amplitudes are directly comparable.
vmin, vmax = np.percentile(observed, [1, 99])
kwargs = dict(vmin=vmin, vmax=vmax, cmap='gray_r', aspect='auto')
axes[0].imshow(observed, **kwargs)
axes[0].set_title('Observed data')
axes[1].imshow(synthetic, **kwargs)
axes[1].set_title('Initial data')
plt.tight_layout()
plt.savefig('figures/Profiles.png', dpi=300, bbox_inches='tight')
plt.show()
# Envelopes of both gathers, computed with seistorch.transform.envelope and
# converted back to NumPy (observed first, then synthetic).
obs_envelope, syn_envelope = (envelope(to_tensor(gather)).numpy()
                              for gather in (observed, synthetic))

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
# Shared clip range taken from the observed envelope.
vmin, vmax = np.percentile(obs_envelope, [1, 99])
kwargs = dict(vmin=vmin, vmax=vmax, cmap='gray_r', aspect='auto')
envelope_panels = ((obs_envelope, 'Envelope of Observed data'),
                   (syn_envelope, 'Envelope of Initial data'))
for panel_ax, (panel_img, panel_title) in zip(axes, envelope_panels):
    panel_ax.imshow(panel_img, **kwargs)
    panel_ax.set_title(panel_title)
plt.tight_layout()
plt.savefig('figures/Envelopes_Profile.png', dpi=300, bbox_inches='tight')
plt.show()
# Overlay one trace with its envelope: observed on the left, synthetic on the right.
trace_no = 25
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
trace_panels = (
    (observed, obs_envelope, 'obs', 'env of obs'),
    (synthetic, syn_envelope, 'syn', 'env of syn'),
)
for panel_ax, (trace_src, trace_env, trace_lbl, env_lbl) in zip(axes, trace_panels):
    panel_ax.plot(trace_src[:, trace_no], label=trace_lbl)
    panel_ax.plot(trace_env[:, trace_no], label=env_lbl)
# Legends are added only after every curve has been drawn.
for panel_ax in axes:
    panel_ax.legend()
plt.tight_layout()
plt.savefig('figures/Envelopes_Trace.png', dpi=300, bbox_inches='tight')
plt.show()
# calculate the envelope loss and l2 loss
envelope_diff = Envelope(method='subtract')  # eq.5 in the paper
envelope_square = Envelope(method='square')  # eq.6 in the paper
envelope_log = Envelope(method='log')  # eq.7 in the paper (does not work; not used below)
l2_criterion = L2()

# Add a leading axis of size 1 (presumably the batch axis expected by the
# seistorch losses -- TODO confirm). Only the synthetic needs gradients.
observed = to_tensor(observed).unsqueeze(0)
synthetic = to_tensor(synthetic).unsqueeze(0)
synthetic.requires_grad = True

loss_envelope_diff = envelope_diff(synthetic, observed)
loss_envelope_square = envelope_square(synthetic, observed)
loss_l2 = l2_criterion(synthetic, observed)

# Adjoint source = d(loss)/d(synthetic). Only first-order gradients are needed
# and they are detached immediately, so no double-backward graph is built.
# (create_graph is left at its default False, which also lets each loss's
# graph be freed after its gradient is taken.)
adj_envelope_diff = to_np(torch.autograd.grad(loss_envelope_diff, synthetic)[0]).squeeze()
adj_envelope_square = to_np(torch.autograd.grad(loss_envelope_square, synthetic)[0]).squeeze()
adj_l2 = to_np(torch.autograd.grad(loss_l2, synthetic)[0]).squeeze()
# show the adjoint fields
# The three adjoint sources do not share an amplitude scale. For example, the
# squared-envelope adjoint weights the data by E_syn**2 - E_obs**2, while the L2
# adjoint does not. Each panel is therefore clipped to its own 1st-99th
# percentiles, because a range taken from the L2 adjoint alone would saturate
# or wash out the envelope panels.
fig, axes = plt.subplots(1, 3, figsize=(8, 3))
adjoint_panels = ((adj_l2, 'Adj by L2'),
                  (adj_envelope_diff, 'Env loss(difference)'),
                  (adj_envelope_square, 'Env loss(square)'))
for panel_ax, (adj_img, adj_title) in zip(axes, adjoint_panels):
    vmin, vmax = np.percentile(adj_img, [1, 99])
    kwargs = dict(vmin=vmin, vmax=vmax, cmap='gray_r', aspect='auto')
    panel_ax.imshow(adj_img, **kwargs)
    panel_ax.set_title(adj_title)
plt.tight_layout()
plt.savefig('figures/Adjoint_sources.png', dpi=300, bbox_inches='tight')
plt.show()
# Frequency spectrum of each adjoint source, each normalised by its own peak.
# dt=0.001 is hard-coded here; presumably it matches the modelling sampling
# interval -- TODO confirm.
kwargs = dict(dt=0.001, end_Freq=50)
freqs, amp_l2 = freq_spectrum(adj_l2, **kwargs)
amp_diff = freq_spectrum(adj_envelope_diff, **kwargs)[1]
amp_square = freq_spectrum(adj_envelope_square, **kwargs)[1]

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
spectrum_curves = ((amp_l2, 'b', 'L2'),
                   (amp_diff, 'r', 'Env diff'),
                   (amp_square, 'g', 'Env square'))
for curve_amp, curve_fmt, curve_lbl in spectrum_curves:
    ax.plot(freqs, curve_amp / curve_amp.max(), curve_fmt, label=curve_lbl)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Normalized Amplitude')
ax.legend()
plt.tight_layout()
plt.savefig('figures/adj_freq_spectrum.png', dpi=300, bbox_inches='tight')
plt.show()
## Calculate the adjoint source by hand v.s. by AD
# Hand-derived adjoint source for the squared-envelope misfit
# (eq 15 in http://dx.doi.org/10.1016/j.jappgeo.2014.07.010):
#     fs2 = (E_syn**2 - E_obs**2) * s - H[(E_syn**2 - E_obs**2) * H[s]]
# where s is the synthetic gather, E_* are the envelopes computed above, and
# H[.] is the Hilbert transform along time (axis 0). H is taken as the
# imaginary part of scipy.signal.hilbert's analytic signal.
# For the envelope-difference misfit the weight would instead be
# (E_syn - E_obs) / E_syn.
factor2 = syn_envelope**2 - obs_envelope**2
syn_hilbert = hilbert(syn, axis=0).imag
fs2 = factor2 * syn - hilbert(factor2 * syn_hilbert, axis=0).imag
adj_envelope = adj_envelope_square

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
compare_panels = ((adj_envelope, 'Adj cal by AD'), (fs2, 'Adj cal by hand'))
for panel_ax, (panel_img, panel_title) in zip(axes, compare_panels):
    # Clipped per panel: the AD and hand results need not share an amplitude scale.
    vmin, vmax = np.percentile(panel_img, [1, 99])
    panel_ax.imshow(panel_img, vmin=vmin, vmax=vmax, cmap='gray_r', aspect='auto')
    panel_ax.set_title(panel_title)
plt.tight_layout()
plt.savefig('figures/Adj_compare.png', dpi=300, bbox_inches='tight')
plt.show()