Skip to content

Commit f42d944

Browse files
committed
Improve GPR fitting and plot rendering
Expand GaussianProcess kernel bounds and increase optimizer restarts to reduce ConvergenceWarnings and improve global hyperparameter search (n_restarts_optimizer raised for CV and final fit). Use neg_mean_squared_error consistently (negated to positive MSE) and attach cross-validation scores to the returned model for later plotting. Also set figure DPI to 300 for higher-resolution output. Minor docs tweak: adjust figure subcaption and remove redundant comments in qs_class.qmd.
1 parent 49b8c03 commit f42d944

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

app/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/qs_class.qmd

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,11 @@ _ = study.pod(poi_col="Length", threshold=18.0)
117117
# | label: fig-oop-results
118118
# | fig-cap: "Reliability Analysis Results"
119119
# | fig-subcap:
120+
# | - "Best Fitting Model (Statistics)"
120121
# | - "Signal Response Model (Physics)"
121122
# | - "Probability of Detection Curve (Reliability)"
122123
# | layout-ncol: 1
123124
# 2. Visualize Results
124-
# This automatically generates the Signal Model
125-
# and PoD Curve plots
126125
study.visualise()
127126
```
128127

src/digiqual/pod.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,22 @@ def fit_robust_mean_model(
4444
# 1. Evaluate Polynomials
4545
for d in degrees:
4646
model = make_pipeline(PolynomialFeatures(degree=d), LinearRegression())
47+
# Use neg_mean_squared_error; we take the negative to get positive MSE
4748
scores = cross_val_score(model, X_2d, y, cv=cv, scoring='neg_mean_squared_error')
4849
cv_scores[('Polynomial', d)] = -np.mean(scores)
4950

5051
# 2. Evaluate Kriging (Gaussian Process)
51-
kernel = C(1.0, (1e-3, 1e3)) * RBF(1.0, (1e-2, 1e2))
52+
# INCREASED BOUNDS: Constant value up to 1e6 and RBF length scale up to 1e5
53+
# This addresses the ConvergenceWarning
54+
kernel = C(1.0, (1e-5, 1e6)) * RBF(1.0, (1e-3, 1e5))
55+
5256
gpr = GaussianProcessRegressor(
53-
kernel=kernel, n_restarts_optimizer=5, alpha=np.var(y)*0.01, random_state=42
57+
kernel=kernel,
58+
n_restarts_optimizer=10, # Increased for better global search
59+
alpha=np.var(y) * 0.01,
60+
random_state=42
5461
)
62+
5563
gpr_scores = cross_val_score(gpr, X_2d, y, cv=cv, scoring='neg_mean_squared_error')
5664
cv_scores[('Kriging', None)] = -np.mean(gpr_scores)
5765

@@ -65,15 +73,17 @@ def fit_robust_mean_model(
6573
final_model.fit(X_2d, y)
6674
final_model.model_params_ = best_params
6775
else:
76+
# Re-initialize with the same wide-bound kernel
6877
final_model = GaussianProcessRegressor(
69-
kernel=kernel, n_restarts_optimizer=10, alpha=np.var(y)*0.01, random_state=42
78+
kernel=kernel,
79+
n_restarts_optimizer=15, # More restarts for the final fit
80+
alpha=np.var(y) * 0.01,
81+
random_state=42
7082
)
7183
final_model.fit(X_2d, y)
7284
final_model.model_params_ = final_model.kernel_
7385

7486
final_model.model_type_ = best_type
75-
76-
# NEW: Attach the scores so they can be plotted later!
7787
final_model.cv_scores_ = cv_scores
7888

7989
return final_model
@@ -117,7 +127,7 @@ def plot_model_selection(cv_scores: dict) -> Any:
117127

118128
# 3. Create the Figure and subplots (Bar chart on left, Table on right)
119129
fig, (ax_plot, ax_table) = plt.subplots(
120-
1, 2, figsize=(12, 5), gridspec_kw={'width_ratios': [2.5, 1]}
130+
    1, 2, figsize=(12, 5), dpi=300, gridspec_kw={'width_ratios': [2.5, 1]}
121131
)
122132

123133
# --- Bar Chart (Left) ---

0 commit comments

Comments
 (0)