Skip to content

Commit 666b91b

Browse files
author
Jawad Chowdhury
committed
added fitting function variation
1 parent fbf132a commit 666b91b

File tree

2 files changed

+228
-3
lines changed

2 files changed

+228
-3
lines changed

BGlib/be/analysis/utils/be_loop.py

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from scipy.special import erf, erfinv
1919
import warnings
2020
from scipy.optimize import curve_fit
21+
from sklearn.metrics import r2_score
22+
from sklearn.gaussian_process import GaussianProcessRegressor
23+
from sklearn.gaussian_process.kernels import Matern
2124

2225
# switching32 = np.dtype([('V+', np.float32),
2326
# ('V-', np.float32),
@@ -1012,7 +1015,6 @@ def intersection(L1, L2):
10121015
except Exception as e:
10131016
print('Error: ', e)
10141017
##
1015-
from sklearn.metrics import r2_score
10161018
best_guess = init_guess_coef_vec.copy()
10171019
try:
10181020
best_guess, _ = curve_fit(loop_fit_function, vdc, pr_vec, p0=init_guess_coef_vec, maxfev=5000)
@@ -1039,6 +1041,228 @@ def intersection(L1, L2):
10391041
##
10401042
return best_guess
10411043

1044+
1045+
def generate_deepGP_guess(vdc, pr_vec, show_plots=False):
1046+
"""
1047+
Given a single unfolded loop and centroid, return the best-fit parameter guess.
1048+
We start with an initial estimate based on the loop centroid and intersection points.
1049+
Then, we refine it by running the fitting program multiple times with randomized perturbations
1050+
around the initial guess and keeping the parameters that yield the lowest fitting error.
1051+
1052+
Parameters
1053+
-----------
1054+
vdc : 1D numpy array
1055+
DC offsets
1056+
pr_vec : 1D numpy array
1057+
Piezoresponse or unfolded loop
1058+
show_plots : Boolean (Optional. Default = False)
1059+
Whether or not the plot the convex hull, centroid, intersection points
1060+
1061+
Returns
1062+
-----------------
1063+
init_guess_coef_vec : 1D Numpy array
1064+
Fit guess coefficient vector
1065+
"""
1066+
1067+
points = np.transpose(np.array([np.squeeze(vdc), pr_vec])) # [points,axis]
1068+
1069+
geom_centroid, geom_area = calculate_loop_centroid(points[:, 0], points[:, 1])
1070+
1071+
hull = ConvexHull(points)
1072+
1073+
"""
1074+
Now we need to find the intersection points on the N,S,E,W
1075+
the simplex of the complex hull is essentially a set of line equations.
1076+
We need to find the two lines (top and bottom) or (left and right) that
1077+
interect with the vertical / horizontal lines passing through the geometric centroid
1078+
"""
1079+
1080+
def find_intersection(A, B, C, D):
1081+
"""
1082+
Finds the coordinates where two line segments intersect
1083+
1084+
Parameters
1085+
------------
1086+
A, B, C, D : Tuple or 1D list or 1D numpy array
1087+
(x,y) coordinates of the points that define the two line segments AB and CD
1088+
1089+
Returns
1090+
----------
1091+
obj : None or tuple
1092+
None if not intersecting. (x,y) coordinates of intersection
1093+
"""
1094+
1095+
def ccw(A, B, C):
1096+
"""Credit - StackOverflow"""
1097+
return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
1098+
1099+
def line(p1, p2):
1100+
"""Credit - StackOverflow"""
1101+
A = (p1[1] - p2[1])
1102+
B = (p2[0] - p1[0])
1103+
C = (p1[0] * p2[1] - p2[0] * p1[1])
1104+
return A, B, -C
1105+
1106+
def intersection(L1, L2):
1107+
"""
1108+
Finds the intersection of two lines (NOT line segments).
1109+
Credit - StackOverflow
1110+
"""
1111+
D = L1[0] * L2[1] - L1[1] * L2[0]
1112+
Dx = L1[2] * L2[1] - L1[1] * L2[2]
1113+
Dy = L1[0] * L2[2] - L1[2] * L2[0]
1114+
if D != 0:
1115+
x = Dx / D
1116+
y = Dy / D
1117+
return x, y
1118+
else:
1119+
return None
1120+
1121+
if ((ccw(A, C, D) is not ccw(B, C, D)) and (ccw(A, B, C) is not ccw(A, B, D))) is False:
1122+
return None
1123+
else:
1124+
return intersection(line(A, B), line(C, D))
1125+
1126+
# start and end coordinates of each line segment defining the convex hull
1127+
outline_1 = np.zeros((hull.simplices.shape[0], 2), dtype=float)
1128+
outline_2 = np.zeros((hull.simplices.shape[0], 2), dtype=float)
1129+
for index, pair in enumerate(hull.simplices):
1130+
outline_1[index, :] = points[pair[0]]
1131+
outline_2[index, :] = points[pair[1]]
1132+
1133+
"""Find the coordinates of the points where the vertical line through the
1134+
centroid intersects with the convex hull"""
1135+
y_intersections = []
1136+
for pair in range(outline_1.shape[0]):
1137+
x_pt = find_intersection(outline_1[pair], outline_2[pair],
1138+
[geom_centroid[0], hull.min_bound[1]],
1139+
[geom_centroid[0], hull.max_bound[1]])
1140+
if x_pt is not None:
1141+
y_intersections.append(x_pt)
1142+
1143+
'''
1144+
Find the coordinates of the points where the horizontal line through the
1145+
centroid intersects with the convex hull
1146+
'''
1147+
x_intersections = []
1148+
for pair in range(outline_1.shape[0]):
1149+
x_pt = find_intersection(outline_1[pair], outline_2[pair],
1150+
[hull.min_bound[0], geom_centroid[1]],
1151+
[hull.max_bound[0], geom_centroid[1]])
1152+
if x_pt is not None:
1153+
x_intersections.append(x_pt)
1154+
1155+
'''
1156+
Default values if not intersections can be found.
1157+
'''
1158+
if len(y_intersections) < 2:
1159+
min_y_intercept = min(pr_vec)
1160+
max_y_intercept = max(pr_vec)
1161+
else:
1162+
min_y_intercept = min(y_intersections[0][1], y_intersections[1][1])
1163+
max_y_intercept = max(y_intersections[0][1], y_intersections[1][1])
1164+
1165+
if len(x_intersections) < 2:
1166+
min_x_intercept = min(vdc) / 2.0
1167+
max_x_intercept = max(vdc) / 2.0
1168+
else:
1169+
min_x_intercept = min(x_intersections[0][0], x_intersections[1][0])
1170+
max_x_intercept = max(x_intersections[0][0], x_intersections[1][0])
1171+
1172+
# Only the first four parameters use the information from the intercepts
1173+
# a3, a4 are swapped in Stephen's figure. That was causing the branches to swap during fitting
1174+
# the a3, a4 are fixed now below:
1175+
init_guess_coef_vec = np.zeros(shape=9)
1176+
init_guess_coef_vec[0] = min_y_intercept
1177+
init_guess_coef_vec[1] = max_y_intercept - min_y_intercept
1178+
init_guess_coef_vec[2] = min_x_intercept
1179+
init_guess_coef_vec[3] = max_x_intercept
1180+
init_guess_coef_vec[4] = 0
1181+
init_guess_coef_vec[5] = 2 # 0.5
1182+
init_guess_coef_vec[6] = 2 # 0.2
1183+
init_guess_coef_vec[7] = 2 # 1.0
1184+
init_guess_coef_vec[8] = 2 # 0.2
1185+
1186+
if show_plots:
1187+
try:
1188+
fig, ax = plt.subplots()
1189+
ax.plot(points[:, 0], points[:, 1], 'o')
1190+
ax.plot(geom_centroid[0], geom_centroid[1], 'r*')
1191+
ax.plot([geom_centroid[0], geom_centroid[0]], [hull.max_bound[1], hull.min_bound[1]], 'g')
1192+
ax.plot([hull.min_bound[0], hull.max_bound[0]], [geom_centroid[1], geom_centroid[1]], 'g')
1193+
for simplex in hull.simplices:
1194+
ax.plot(points[simplex, 0], points[simplex, 1], 'k')
1195+
ax.plot(x_intersections[0][0], x_intersections[0][1], 'r*')
1196+
ax.plot(x_intersections[1][0], x_intersections[1][1], 'r*')
1197+
ax.plot(y_intersections[0][0], y_intersections[0][1], 'r*')
1198+
ax.plot(y_intersections[1][0], y_intersections[1][1], 'r*')
1199+
ax.plot(vdc, loop_fit_function(vdc, *init_guess_coef_vec))
1200+
except Exception as e:
1201+
print('Error: ', e)
1202+
##
1203+
X = []
1204+
y = []
1205+
#
1206+
best_guess = init_guess_coef_vec.copy()
1207+
best_err = np.inf
1208+
best_r2 = -np.inf
1209+
def objective(params):
1210+
"""Returns error for a given parameter vector."""
1211+
try:
1212+
guess, _ = curve_fit(loop_fit_function, vdc, pr_vec, p0=params, maxfev=5000)
1213+
pred = loop_fit_function(vdc, *guess)
1214+
err = np.sum((pr_vec - pred) ** 2)
1215+
r2 = r2_score(pr_vec, pred)
1216+
return err, r2, guess
1217+
except RuntimeError:
1218+
return np.inf, -np.inf, params
1219+
#
1220+
for _ in range(10):
1221+
p0_rand = init_guess_coef_vec * (1 + 0.3 * np.random.randn(len(init_guess_coef_vec)))
1222+
err, r2, guess = objective(p0_rand)
1223+
if np.isfinite(err):
1224+
X.append(p0_rand)
1225+
y.append(err)
1226+
if err < best_err:
1227+
best_guess, best_err, best_r2 = guess, err, r2
1228+
if best_r2 >= 0.95:
1229+
break
1230+
#
1231+
kernel = Matern(length_scale=1.0, nu=2.5)
1232+
gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True)
1233+
#
1234+
for trial_no in range(50):
1235+
if best_r2 >= 0.95:
1236+
break
1237+
X_arr = np.array(X)
1238+
y_arr = np.array(y)
1239+
y_arr = np.nan_to_num(y_arr, nan=1e6, posinf=1e6, neginf=1e6)
1240+
if len(y_arr) == 0:
1241+
X_arr = np.array([init_guess_coef_vec])
1242+
y_arr = np.array([1e6])
1243+
1244+
gp.fit(X_arr, y_arr)
1245+
scale = np.maximum(np.abs(best_guess), 1e-3)
1246+
candidates = best_guess * (1 + 0.3 * np.random.randn(100, len(best_guess)) / scale)
1247+
1248+
mu, sigma = gp.predict(candidates, return_std=True)
1249+
1250+
# Choose the most promising candidate (Expected Improvement heuristic)
1251+
acquisition = mu - 1.0 * sigma # exploit low mean, explore high uncertainty
1252+
next_p = candidates[np.argmin(acquisition)]
1253+
1254+
err, r2, guess = objective(next_p)
1255+
if np.isfinite(err):
1256+
X.append(next_p)
1257+
y.append(err)
1258+
1259+
if err < best_err:
1260+
best_guess, best_err, best_r2 = guess, err, r2
1261+
##
1262+
return best_guess
1263+
1264+
1265+
10421266
###############################################################################
10431267

10441268

BGlib/misc/bg_gui_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from BGlib.be.analysis.utils.sidpy_sho_fitter import SHOestimateGuess, SHOestimateGuess, SHO_fit_flattened
3131
from PyQt5.QtGui import QTextCursor
3232
from PyQt5.QtCore import QObject, pyqtSignal
33-
from BGlib.be.analysis.utils.be_loop import projectLoop, loop_fit_function, generate_guess, generate_shallow_guess, generate_deep_guess, calc_switching_coef_vec
33+
from BGlib.be.analysis.utils.be_loop import projectLoop, loop_fit_function, generate_guess, generate_shallow_guess, generate_deep_guess, generate_deepGP_guess, calc_switching_coef_vec
3434

3535

3636
class EmittingStream(QObject):
@@ -434,7 +434,8 @@ def on_do_loop_guess(self):
434434
d_guess_fn = {
435435
'Basic': generate_guess,
436436
'Shallow': generate_shallow_guess,
437-
'Deep': generate_deep_guess
437+
'Deep': generate_deep_guess,
438+
'DeepGP': generate_deepGP_guess
438439
}
439440
##
440441
self.do_loop_fit_button.setEnabled(False)

0 commit comments

Comments
 (0)