Commit fa6e40b
Pre-v0.11.0 PR - Add Kriging Model & update diagnostic testing (#10)
* Add Kriging mean model and update PoD pipeline

  Introduce a Gaussian Process (Kriging) as an alternative to polynomial mean models and wire it through the PoD pipeline and UI.

  - src/digiqual/pod.py: implement fit_robust_mean_model to evaluate both polynomial degrees and a GaussianProcessRegressor via cross-validation, expose model_type_ and model_params_ on the returned model, and update bootstrap_pod_ci to accept model_type/model_params and handle Kriging (disable the optimizer during bootstrap).
  - src/digiqual/core.py: adjust logging/messages and pass the new model_type/model_params into bootstrap_pod_ci.
  - app/app.py: update UI metrics to display the chosen Mean Model and Error Distribution string.
  - tests/test_pod.py: reflect dynamic model selection, add a Kriging-specific bootstrap test, and adjust assertions/fixtures accordingly.

  Overall this enables dynamic mean-model selection (Polynomial vs Kriging) and keeps bootstrap and UI behavior consistent with the new models.

* Add model-selection plot, refinement UI, and diagnostics tweaks

  Add a targeted-sampling UI with download, introduce a model-selection visualization, improve bootstrap diagnostics, and add example scripts.

  Key changes:
  - app/app.py: add a handler for the refinement button that calls Study.refine(), store the generated samples, and provide a CSV download UI; add a plot_model_selection render output and reorganize the results layout.
  - app/run_app.py: fix the JS filename-map key to match the new download id ('download_new_samples').
  - src/digiqual/pod.py: attach cv_scores_ to the fitted mean model, remove the plot_cv flag, and add plot_model_selection() to render a normalized bias-variance bar chart plus an MSE table.
  - src/digiqual/core.py: save/show the new model_selection plot when available in SimulationStudy outputs.
  - src/digiqual/diagnostics.py: revise the bootstrap convergence routine (use a degree-2 polynomial, use relative standard deviation / CV, and average + max thresholds) and update the sample_sufficiency report to include the thresholds and both avg/max CV metrics.
  - scripts/kriging_run.py: new example script that generates non-linear (sigmoid) data and runs a PoD analysis with SimulationStudy.
  - scripts/make_fake_data.py: modify the fake-data generators to create deliberate gap/heteroskedastic examples, adjust filenames and sizes for testing, and replace normal noise with gamma noise.
  - tests/test_pod.py: update imports and add a test that plot_model_selection produces a figure and that the fitted model has cv_scores_.

  Rationale: provide a visual model-selection tool for the bias-variance tradeoff, expose refinement/download of newly generated targeted samples in the UI, and make bootstrap convergence checks more robust to heteroskedasticity and tail behavior. Also add example scripts for Kriging and improved synthetic data for diagnostics testing.
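The fit_robust_mean_model changes in src/digiqual/pod.py are only summarized above, not shown in the diffs below. A minimal sketch of the cross-validated choice between polynomial degrees and a Gaussian Process that the message describes might look like the following; the helper name `select_mean_model`, the candidate degrees, and the RBF+White kernel are illustrative assumptions, not the actual digiqual API:

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

def select_mean_model(X, y, degrees=(1, 2, 3)):
    """Pick the mean model with the lowest cross-validated MSE.

    Hypothetical helper: compares polynomial fits of several degrees
    against a GP (Kriging) candidate and returns (model_type,
    model_params, cv_scores), mirroring the model_type_/model_params_
    attributes described in the commit message.
    """
    X2d = np.asarray(X).reshape(-1, 1)
    scores = {}
    for d in degrees:
        poly = make_pipeline(PolynomialFeatures(degree=d), LinearRegression())
        # cross_val_score returns negative MSE; negate to get MSE
        scores[("Polynomial", d)] = -cross_val_score(
            poly, X2d, y, cv=5, scoring="neg_mean_squared_error").mean()
    gp = GaussianProcessRegressor(kernel=RBF() + WhiteKernel(), normalize_y=True)
    scores[("Kriging", None)] = -cross_val_score(
        gp, X2d, y, cv=5, scoring="neg_mean_squared_error").mean()
    (model_type, model_params), _ = min(scores.items(), key=lambda kv: kv[1])
    return model_type, model_params, scores
```

The per-candidate CV scores returned here are the kind of data the new plot_model_selection() bar chart would consume.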
1 parent 3ec7c97 commit fa6e40b

File tree

8 files changed

+382
-237
lines changed

app/app.py

Lines changed: 73 additions & 10 deletions
@@ -477,6 +477,24 @@ def initialize_column_selectors():
         ui.update_selectize("input_cols", choices=cols, selected=default_inputs)
         ui.update_selectize("outcome_col", choices=cols, selected=default_outcome)

+    @reactive.effect
+    @reactive.event(input.btn_refine)
+    def handle_refinement():
+        study = current_study()
+        if study is None:
+            return
+
+        try:
+            # We assume your SimulationStudy has a refine method
+            # that targets the 'Length' gaps we discussed earlier
+            n_to_gen = input.n_new_samples()
+            refined_df = study.refine(n_points=n_to_gen)  # Or your specific generation logic
+
+            new_samples.set(refined_df)
+            ui.notification_show(f"Generated {n_to_gen} targeted samples.", type="message")
+        except Exception as e:
+            ui.notification_show(f"Refinement failed: {e}", type="error")
+
     @render.ui
     def selection_error_display():
         """Displays a permanent red error if selections conflict."""
@@ -605,6 +623,29 @@ def remediation_ui():
         )

+    @render.ui
+    def download_new_samples_ui():
+        # Only show the button if new_samples has been populated
+        if new_samples() is None:
+            return None
+
+        return ui.div(
+            ui.hr(),
+            ui.p("Success! Download your targeted samples below:", class_="small"),
+            ui.download_button(
+                "download_new_samples",
+                "Download Refined CSV",
+                class_="btn-success w-100",
+                icon=icon_svg("download")
+            )
+        )
+
+    @render.download(filename="remediation_samples.csv")
+    def download_new_samples():
+        df = new_samples()
+        if df is not None:
+            yield df.to_csv(index=False)
+

     #### Server - PoD Generation (Tab 4) ####

@@ -709,19 +750,25 @@ def compute_pod_analysis():
             val = results["a90_95"]
             a9095_str = f"{val:.3f}" if not np.isnan(val) else "Not Reached"

+            # 3. Format the Mean Model string based on the new architecture
+            mean_model = results["mean_model"]
+            if mean_model.model_type_ == 'Polynomial':
+                model_str = f"Polynomial (Degree {mean_model.model_params_})"
+            else:
+                model_str = "Kriging (Gaussian Process)"

-
-            # 3. Create Metrics Dictionary for the UI
+            # 4. Create Metrics Dictionary for the UI
            metrics = {
                "Parameter of Interest": results["poi_col"],
                "Threshold": results["threshold"],
                "a90/95": a9095_str,
-               "Model Degree": results["mean_model"].best_degree_,
+               "Mean Model": model_str,
                "Smoothing Bandwidth": f"{results['bandwidth']:.4f}",
+               "Error Distribution": results["dist_info"][0].capitalize()
            }
            pod_metrics.set(metrics)

-           # 4. Prepare Data for Download
+           # 5. Prepare Data for Download
            export_df = pd.DataFrame({
                "x_defect_size": results["X_eval"],
                "pod_mean": results["curves"]["pod"],
@@ -730,25 +777,32 @@ def compute_pod_analysis():
            })
            pod_export_data.set(export_df)

-           # 5. Generate Plots (Visualise draws them internally)
+           # 6. Generate Plots (Visualise draws them internally)
            study.visualise(show=False)
            plot_trigger.set(plot_trigger() + 1)

        except Exception as e:
            ui.notification_show(f"Analysis Failed: {str(e)}", type="error")


-    # --- RESULTS DISPLAY ---
+    # --- RESULTS DISPLAY ---
     @render.ui
     def pod_results_container():
         """
-        Renders the side-by-side plots and the metrics table.
+        Renders the model selection plot, side-by-side analysis plots, and the metrics table.
         """
         if pod_metrics() is None:
             return ui.div()

         return ui.div(
-            # Row 1: Plots
+            # Row 1: Model Selection Plot (Full Width)
+            ui.card(
+                ui.card_header("Model Selection (Bias-Variance Tradeoff)"),
+                ui.output_plot("plot_model_selection", height="400px"),
+                full_screen=True,
+                class_="mb-3"
+            ),
+            # Row 2: Signal Model and PoD Plots
             ui.layout_columns(
                 ui.card(
                     ui.card_header("Signal Model Fit"),
@@ -760,9 +814,10 @@ def pod_results_container():
                     ui.output_plot("plot_curve"),
                     full_screen=True
                 ),
-                col_widths=[6, 6]
+                col_widths=[6, 6],
+                class_="mb-3"
             ),
-            # Row 2: Table and Download Actions
+            # Row 3: Table and Download Actions
             ui.layout_columns(
                 ui.card(
                     ui.card_header("Key Reliability Metrics"),
@@ -780,6 +835,14 @@ def pod_results_container():
             )
         )

+    @render.plot
+    def plot_model_selection():
+        _ = plot_trigger()  # Dependency on button click
+        study = current_study()
+        if study and "model_selection" in study.plots:
+            return study.plots["model_selection"]
+        return None
+
     @render.plot
     def plot_signal():
         _ = plot_trigger()  # Dependency on button click

app/run_app.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def inject_js(window):
     // Map Button IDs to Filenames
     const filenameMap = {
         'download_lhs': 'experimental_design.csv',
-        'download_new_samples_csv': 'refinement_samples.csv',
+        'download_new_samples': 'refinement_samples.csv',
         'download_pod_results': 'pod_analysis_results.csv'
     };

scripts/kriging_run.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import numpy as np
+import pandas as pd
+from digiqual.core import SimulationStudy
+
+print("Generating synthetic non-linear data...")
+# 1. Generate Non-linear Data (Sigmoid Curve)
+# This shape is difficult for polynomials but perfect for Kriging.
+np.random.seed(42)
+flaw_sizes = np.linspace(0.1, 10.0, 150)
+
+# Sigmoid function: plateaus at the top and bottom
+true_responses = 20 / (1 + np.exp(-1.5 * (flaw_sizes - 5)))
+# Add noise that scales slightly with the flaw size
+noise = np.random.normal(0, 1.0 + 0.1 * flaw_sizes, size=len(flaw_sizes))
+responses = true_responses + noise
+
+df = pd.DataFrame({
+    'Flaw_Size': flaw_sizes,
+    'Response': responses
+})
+
+# 3. Initialize the Study
+print("Initializing SimulationStudy...")
+study = SimulationStudy(input_cols=['Flaw_Size'], outcome_col='Response')
+study.add_data(df)
+study.diagnose()
+
+# 4. Run the PoD Analysis
+# We use a threshold that intersects the middle of our S-Curve (e.g., 10.0)
+# Using 100 bootstrap iterations so it runs relatively quickly for testing
+print("\n--- Running PoD Analysis ---")
+results = study.pod(poi_col='Flaw_Size', threshold=10.0, n_boot=100)
+
+# 5. Show the Final Visualizations
+print("\n--- Generating Visualizations ---")
+study.visualise()

scripts/make_fake_data.py

Lines changed: 20 additions & 26 deletions
@@ -1,48 +1,42 @@
 import pandas as pd
 import numpy as np

-def generate_fake_data(filename="initial_data.csv", n=50):
-    """Generates a small dataset that might FAIL diagnostics (for testing the 'Fix' loop)."""
-    np.random.seed(42)
-
-    # 1. Generate Inputs (Small N = likely gaps)
+def generate_fake_data(filename="app/initial_data.csv", n=25):
+    """Fails due to massive Gaps and Skewed Heteroskedasticity."""
+    # 1. Deliberate Gap (0-2 and 8-10)
+    lengths = np.concatenate([np.random.uniform(0, 2, 12), np.random.uniform(8, 10, 13)])
     df = pd.DataFrame({
-        'Length': np.random.uniform(0, 10, n),
+        'Length': lengths,
         'Angle': np.random.uniform(-45, 45, n)
     })

-    # 2. Physics & Noise
-    base_signal = (df['Length'] * 2.0) - (0.1 * df['Angle'].abs())
-    noise_scale = 0.5 + (0.1 * df['Length'])
-    noise = np.random.normal(loc=0, scale=noise_scale, size=n)
+    # 2. Monotonic Physics + Skewed Gamma Noise
+    # As Length increases, the 'scale' of the Gamma noise increases (Heteroskedasticity)
+    base_signal = 10.0 + 1.5 * df['Length'] + 0.2 * (df['Length']**2)

-    df['Signal'] = np.abs(base_signal + noise)
+    # Non-normal noise: Gamma distribution is always positive and skewed
+    noise_scale = 0.5 + (0.8 * df['Length'])
+    noise = np.random.gamma(shape=2.0, scale=noise_scale, size=n)

+    df['Signal'] = base_signal + noise
     df.to_csv(filename, index=False)
-    print(f"✅ Created '{filename}' with {n} rows (likely to have issues).")
-
+    print(f"✅ Created '{filename}' (N={n}). Should fail Gap and Bootstrap.")

-def updated_data(filename="sufficient_data.csv", n=200):
-    """Generates a large dataset that should PASS all diagnostics."""
-    np.random.seed(999)  # Different seed
-
-    # 1. Generate Inputs (Large N = good coverage)
+def updated_data(filename="app/sufficient_data.csv", n=1500):
+    """Passes because high N overcomes the skewed noise."""
     df = pd.DataFrame({
         'Length': np.random.uniform(0, 10, n),
         'Angle': np.random.uniform(-45, 45, n)
     })

-    # 2. Physics & Noise
-    base_signal = (df['Length'] * 2.0) - (0.1 * df['Angle'].abs())
-    noise_scale = 0.5 + (0.1 * df['Length'])
-    noise = np.random.normal(loc=0, scale=noise_scale, size=n)
-
-    df['Signal'] = np.abs(base_signal + noise)
+    base_signal = 10.0 + 1.5 * df['Length'] + 0.2 * (df['Length']**2)
+    noise_scale = 0.5 + (0.8 * df['Length'])
+    noise = np.random.gamma(shape=2.0, scale=noise_scale, size=n)

+    df['Signal'] = base_signal + noise
     df.to_csv(filename, index=False)
-    print(f"✅ Created '{filename}' with {n} rows (should pass checks).")
+    print(f"✅ Created '{filename}' (N={n}). Should pass all tests.")

 if __name__ == "__main__":
-    # You can comment out the one you don't want, or run both
     generate_fake_data()
     updated_data()
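The gamma-noise datasets above are built to exercise the revised bootstrap convergence routine in src/digiqual/diagnostics.py, whose diff is not shown on this page. As a rough sketch of the relative-std-dev (coefficient-of-variation) criterion with average and max thresholds that the commit message describes — the function name and the threshold values here are illustrative assumptions:

```python
import numpy as np

def bootstrap_convergence_cv(boot_curves, avg_thresh=0.05, max_thresh=0.15):
    """Hypothetical CV-based convergence check.

    boot_curves: array of shape (n_boot, n_eval) holding bootstrap
    replicates of a curve evaluated on a common grid. The check passes
    when the average CV across the grid is below avg_thresh AND the
    worst-point CV is below max_thresh, which is more robust to
    heteroskedastic, heavy-tailed noise than an absolute-std criterion.
    """
    mean = boot_curves.mean(axis=0)
    std = boot_curves.std(axis=0)
    # Relative std dev (CV); guard against division by ~zero means
    with np.errstate(divide="ignore", invalid="ignore"):
        cv = np.where(np.abs(mean) > 1e-12, std / np.abs(mean), 0.0)
    return {
        "avg_cv": float(cv.mean()),
        "max_cv": float(cv.max()),
        "converged": bool(cv.mean() < avg_thresh and cv.max() < max_thresh),
    }
```

Reporting both avg_cv and max_cv matches the updated sample_sufficiency report, which now exposes the thresholds alongside both metrics.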

src/digiqual/core.py

Lines changed: 13 additions & 4 deletions
@@ -267,7 +267,7 @@ def optimise(
         self.data = pd.DataFrame()  # Clear old state to avoid duplication
         self.add_data(final_data)

-    #### PoD Analysis ####
+    #### PoD Analysis ####
     def pod(
         self,
         poi_col: str,
@@ -308,10 +308,13 @@ def pod(
         X = self.clean_data[poi_col].values
         y = self.clean_data[self.outcome].values

-        # 2. Fit Mean Model (Robust Polynomial)
+        # 2. Fit Mean Model (Robust Regression)
         print("1. Selecting Mean Model (Cross-Validation)...")
         mean_model = pod.fit_robust_mean_model(X, y)
-        print(f" -> Selected Polynomial Degree: {mean_model.best_degree_}")
+        if mean_model.model_type_ == 'Polynomial':
+            print(f"-> Selected Model: Polynomial (Degree {mean_model.model_params_})")
+        else:
+            print("-> Selected Model: Kriging (Gaussian Process)")

         # 3. Fit Variance Model & Generate Grid
         print("2. Fitting Variance Model (Kernel Smoothing)...")
@@ -335,7 +338,7 @@ def pod(
         print(f"5. Running Bootstrap ({n_boot} iterations)...")
         lower_ci, upper_ci = pod.bootstrap_pod_ci(
             X, y, X_eval, threshold,
-            mean_model.best_degree_, bandwidth, (dist_name, dist_params),
+            mean_model.model_type_, mean_model.model_params_, bandwidth, (dist_name, dist_params),
             n_boot=n_boot
         )

@@ -397,6 +400,10 @@ def visualise(self, show: bool = True, save_path: str = None) -> None:
             res["X"], res["residuals"], res["X_eval"], res["bandwidth"]
         )

+        # 0. Model Selection Plot (NEW)
+        if hasattr(res["mean_model"], "cv_scores_"):
+            self.plots["model_selection"] = pod.plot_model_selection(res["mean_model"].cv_scores_)
+
         # 1. Signal Model Plot
         self.plots["signal_model"] = plot_signal_model(
             X=res["X"],
@@ -420,6 +427,8 @@ def visualise(self, show: bool = True, save_path: str = None) -> None:

         # Handle Saving
         if save_path:
+            if "model_selection" in self.plots:
+                self.plots["model_selection"].savefig(f"{save_path}_model_selection.png")
             self.plots["signal_model"].get_figure().savefig(f"{save_path}_signal.png")
             self.plots["pod_curve"].get_figure().savefig(f"{save_path}_pod.png")
             print(f"Plots saved to {save_path}_*.png")
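The commit message notes that bootstrap_pod_ci disables the Kriging optimizer during bootstrap resampling; the pod.py diff itself is not shown on this page. A standalone sketch of that idea with scikit-learn follows — the helper name `bootstrap_kriging_means` and the kernel choice are assumptions, but `optimizer=None` with a pre-fitted `kernel_` is the standard way to freeze GP hyperparameters:

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

def bootstrap_kriging_means(X, y, X_eval, n_boot=50, seed=0):
    """Refit a GP on resampled data, reusing the kernel fitted once.

    Fitting GP hyperparameters is the expensive part; freezing them
    (optimizer=None) makes each bootstrap iteration cheap and avoids
    unstable hyperparameter estimates on small resamples.
    """
    rng = np.random.default_rng(seed)
    X2d = np.asarray(X).reshape(-1, 1)
    Xe = np.asarray(X_eval).reshape(-1, 1)

    # One full fit to learn the kernel hyperparameters
    base = GaussianProcessRegressor(kernel=RBF() + WhiteKernel(),
                                    normalize_y=True).fit(X2d, y)
    frozen = base.kernel_  # fitted hyperparameters

    preds = np.empty((n_boot, Xe.shape[0]))
    for b in range(n_boot):
        idx = rng.integers(0, len(y), len(y))  # resample with replacement
        gp_b = GaussianProcessRegressor(kernel=frozen, optimizer=None,
                                        normalize_y=True)
        gp_b.fit(X2d[idx], y[idx])
        preds[b] = gp_b.predict(Xe)
    return preds
```

Percentiles of `preds` along axis 0 would then give pointwise confidence bands, analogous to the lower_ci/upper_ci returned by bootstrap_pod_ci above.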
