Skip to content

Commit 37753d2

Browse files
committed
update nugget
1 parent 064caf9 commit 37753d2

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

gstatsim.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -605,19 +605,27 @@ def covar(effective_lag, sill, nug, vtype, s=None):
605605
c : numpy.ndarray
606606
covariance
607607
"""
608+
609+
structured_sill = sill - nug # C_struct = C0 - b
608610

609611
if vtype.lower() == 'exponential':
610-
c = (sill - nug)*np.exp(-3 * effective_lag)
612+
c = structured_sill*np.exp(-3 * effective_lag)
611613
elif vtype.lower() == 'gaussian':
612-
c = (sill - nug)*np.exp(-3 * np.square(effective_lag))
614+
c = structured_sill*np.exp(-3 * np.square(effective_lag))
613615
elif vtype.lower() == 'spherical':
614-
c = sill - nug - 1.5 * effective_lag + 0.5 * np.power(effective_lag, 3)
615-
c[effective_lag > 1] = sill - 1
616+
rho = np.zeros_like(effective_lag)
617+
mask = effective_lag <= 1.0
618+
h = effective_lag[mask]
619+
rho[mask] = 1.0 - 1.5*h + 0.5*np.power(h, 3)
620+
c = structured_sill * rho
616621
elif vtype.lower() == 'matern':
617-
scale = 0.45246434*np.exp(-0.70449189*s)+1.7863836
618-
effective_lag[effective_lag==0.0] = 1e-8
619-
c = (sill-nug)*2/gamma(s)*np.power(scale*effective_lag*np.sqrt(s), s)*kv(s, 2*scale*effective_lag*np.sqrt(s))
620-
c[np.isnan(c)] = sill-nug
622+
if s is None:
623+
raise ValueError("smoothness s must be specified for Matern covariance")
624+
scale = 0.45246434*np.exp(-0.70449189*s) + 1.7863836
625+
eff = np.array(effective_lag, copy=True)
626+
eff[eff == 0.0] = 1e-8
627+
c = structured_sill*2/gamma(s)*np.power(scale*eff*np.sqrt(s), s)*kv(s, 2*scale*eff*np.sqrt(s))
628+
c[np.isnan(c)] = structured_sill # ρ(0) = 1 → C(0-) = structured_sill
621629
else:
622630
raise AttributeError(f"vtype must be 'Exponential', 'Gaussian', 'Spherical', or Matern")
623631
return c
@@ -656,6 +664,7 @@ def make_covariance_matrix(coord, vario, rotation_matrix):
656664
mat = np.matmul(coord, rotation_matrix)
657665
effective_lag = pairwise_distances(mat,mat)
658666
covariance_matrix = Covariance.covar(effective_lag, sill, nug, vtype, s=s)
667+
np.fill_diagonal(covariance_matrix, sill)
659668

660669
return covariance_matrix
661670

@@ -787,7 +796,8 @@ def skrige(prediction_grid, df, xx, yy, zz, num_points, vario, radius, quiet=Fal
787796

788797
est_sk[z] = mean_1 + (np.sum(k_weights*(norm_data_val[:] - mean_1)))
789798
var_sk[z] = var_1 - np.sum(k_weights*covariance_array)
790-
var_sk[var_sk < 0] = 0
799+
if var_sk[z] < 0:
800+
var_sk[z] = 0
791801
else:
792802
est_sk[z] = df['Z'].values[np.where(test_idx==2)[0][0]]
793803
var_sk[z] = 0
@@ -879,7 +889,8 @@ def okrige(prediction_grid, df, xx, yy, zz, num_points, vario, radius, quiet=Fal
879889

880890
est_ok[z] = local_mean + np.sum(k_weights[0:new_num_pts]*(norm_data_val[:] - local_mean))
881891
var_ok[z] = var_1 - np.sum(k_weights[0:new_num_pts]*covariance_array[0:new_num_pts])
882-
var_ok[var_ok < 0] = 0
892+
if var_ok[z] < 0:
893+
var_ok[z] = 0
883894
else:
884895
est_ok[z] = df['Z'].values[np.where(test_idx==2)[0][0]]
885896
var_ok[z] = 0
@@ -948,7 +959,7 @@ def skrige_sgs(prediction_grid, df, xx, yy, zz, num_points, vario, radius, seed=
948959
for idx, predxy in enumerate(tqdm(prediction_grid, position=0, leave=True, disable=quiet)):
949960
z = xyindex[idx]
950961
test_idx = np.sum(prediction_grid[z]==df[['X', 'Y']].values, axis=1)
951-
if np.sum(test_idx==2)==0:
962+
if np.sum(test_idx==2)==0 or vario[1] > 0:
952963

953964
# get nearest neighbors
954965
nearest = NearestNeighbor.nearest_neighbor_search(radius, num_points,
@@ -1044,7 +1055,7 @@ def okrige_sgs(prediction_grid, df, xx, yy, zz, num_points, vario, radius, seed=
10441055
for idx, predxy in enumerate(tqdm(prediction_grid, position=0, leave=True, disable=quiet)):
10451056
z = xyindex[idx]
10461057
test_idx = np.sum(prediction_grid[z]==df[['X', 'Y']].values,axis = 1)
1047-
if np.sum(test_idx==2)==0:
1058+
if np.sum(test_idx==2)==0 or vario[1] > 0:
10481059

10491060
# gather nearest neighbor points
10501061
nearest = NearestNeighbor.nearest_neighbor_search(radius, num_points,

0 commit comments

Comments
 (0)