Skip to content

Commit a786ac7

Browse files
add healthcare ML breast cancer classification guide
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 55a1563 commit a786ac7

File tree

8 files changed

+1745
-0
lines changed

8 files changed

+1745
-0
lines changed
188 KB
Loading
141 KB
Loading
62.7 KB
Loading

site/sfguides/src/healthcare-ml-breast-cancer-classification/healthcare-ml-breast-cancer-classification.md

Lines changed: 498 additions & 0 deletions
Large diffs are not rendered by default.

site/sfguides/src/healthcare-ml-breast-cancer-classification/notebooks/0_start_here.ipynb

Lines changed: 755 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "intro",
6+
"metadata": {},
7+
"source": [
8+
"# Part 2: Snowflake Model Registry Deployment\n",
9+
"\n",
10+
"## Overview\n",
11+
"\n",
12+
"This notebook demonstrates **deploying an XGBoost model to Snowflake Model Registry** for production inference. You'll save training data to Snowflake tables and register your model for scalable, governed ML operations.\n",
13+
"\n",
14+
"### Prerequisites\n",
15+
"\n",
16+
"⚠️ **IMPORTANT**: Run `setup.sql` as ACCOUNTADMIN before starting this notebook.\n",
17+
"\n",
18+
"The setup script creates:\n",
19+
"- Role: `HEALTHCARE_ML_ROLE`\n",
20+
"- Database: `HEALTHCARE_ML`\n",
21+
"- Schema: `HEALTHCARE_ML.DIAGNOSTICS`\n",
22+
"- Warehouse: `HEALTHCARE_ML_WH`\n",
23+
"- Compute Pool: `HEALTHCARE_ML_CPU_POOL`\n",
24+
"\n",
25+
"### What You'll Learn\n",
26+
"\n",
27+
"1. **Persist data** to Snowflake tables\n",
28+
"2. **Register models** in Snowflake Model Registry\n",
29+
"3. **Run inference** using registered models\n",
30+
"4. **Track metadata** (metrics, versions, comments)\n",
31+
"\n",
32+
"> **Note**: This notebook requires Container Runtime and must be run from **Snowsight**."
33+
]
34+
},
35+
{
36+
"cell_type": "markdown",
37+
"id": "load_intro",
38+
"metadata": {},
39+
"source": [
40+
"## Step 1: Load Artifacts from Part 1\n",
41+
"\n",
42+
"Load the trained model and data from `/tmp` that were saved in Part 1."
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": null,
48+
"id": "load_artifacts",
49+
"metadata": {},
50+
"outputs": [],
51+
"source": [
52+
"import pickle\n",
53+
"import pandas as pd\n",
54+
"from snowflake.snowpark.context import get_active_session\n",
55+
"\n",
56+
"# Load artifacts from Part 1\n",
57+
"with open('/tmp/breast_cancer_artifacts.pkl', 'rb') as f:\n",
58+
" artifacts = pickle.load(f)\n",
59+
"\n",
60+
"best_model = artifacts['best_model']\n",
61+
"X_train = artifacts['X_train']\n",
62+
"X_test = artifacts['X_test']\n",
63+
"y_train = artifacts['y_train']\n",
64+
"y_test = artifacts['y_test']\n",
65+
"test_accuracy = artifacts['test_accuracy']\n",
66+
"test_f1 = artifacts['test_f1']\n",
67+
"roc_auc = artifacts['roc_auc']\n",
68+
"pr_auc = artifacts['pr_auc']\n",
69+
"cv_results = artifacts['cv_results']\n",
70+
"feature_names = artifacts['feature_names']\n",
71+
"\n",
72+
"print(\"=\" * 60)\n",
73+
"print(\"✅ ARTIFACTS LOADED FROM /tmp\")\n",
74+
"print(\"=\" * 60)\n",
75+
"print(f\"Model: XGBoost ({best_model.n_estimators} estimators)\")\n",
76+
"print(f\"Training data: {X_train.shape[0]} samples × {X_train.shape[1]} features\")\n",
77+
"print(f\"Test data: {X_test.shape[0]} samples\")\n",
78+
"print(f\"Test Accuracy: {test_accuracy:.4f}\")\n",
79+
"print(f\"ROC AUC: {roc_auc:.4f}\")\n",
80+
"\n",
81+
"# Connect to Snowflake\n",
82+
"session = get_active_session()\n",
83+
"session.sql(\"\"\"\n",
84+
" ALTER SESSION SET query_tag = '{\"origin\":\"sf_sit-is\",\"name\":\"healthcare_ml_classification\",\"version\":{\"major\":1,\"minor\":0},\"attributes\":{\"is_quickstart\":1,\"source\":\"notebook\"}}'\n",
85+
"\"\"\").collect()\n",
86+
"print(f\"\\n✅ Connected to Snowflake: {session.get_current_account()}\")"
87+
]
88+
},
89+
{
90+
"cell_type": "markdown",
91+
"id": "a18fec30",
92+
"metadata": {},
93+
"source": [
94+
"## Step 1: Environment Setup\n",
95+
"\n",
96+
"### Import Libraries\n",
97+
"\n",
98+
"We'll use a combination of data science and Snowflake-specific libraries:\n",
99+
"\n",
100+
"| Library | Purpose |\n",
101+
"|---------|---------|\n",
102+
"| `snowflake.snowpark` | Snowflake session management |\n",
103+
"| `pandas`, `numpy` | Data manipulation and numerical operations |\n",
104+
"| `matplotlib`, `seaborn` | Statistical visualizations |\n",
105+
"| `sklearn` | ML utilities, metrics, and baseline models |\n",
106+
"| `xgboost` | Gradient boosting implementation |\n",
107+
"\n",
108+
"> **Note**: All libraries are pre-installed in Container Runtime - no `!pip install` or EAIs needed."
109+
]
110+
},
111+
{
112+
"cell_type": "code",
113+
"execution_count": null,
114+
"id": "9ad41959",
115+
"metadata": {},
116+
"outputs": [],
117+
"source": [
118+
"from snowflake.ml.registry import Registry\n",
119+
"from snowflake.ml.model import task\n",
120+
"\n",
121+
"DATABASE = \"HEALTHCARE_ML\"\n",
122+
"SCHEMA = \"DIAGNOSTICS\"\n",
123+
"\n",
124+
"session.use_database(DATABASE)\n",
125+
"session.use_schema(SCHEMA)\n",
126+
"\n",
127+
"registry = Registry(session=session)\n",
128+
"\n",
129+
"MODEL_NAME = \"BREAST_CANCER_CLASSIFIER\"\n",
130+
"\n",
131+
"print(\"Logging model to Snowflake Model Registry...\")\n",
132+
"mv = registry.log_model(\n",
133+
" best_model,\n",
134+
" model_name=MODEL_NAME,\n",
135+
" sample_input_data=X_train.head(),\n",
136+
" target_platforms=[\"WAREHOUSE\"],\n",
137+
" task=task.Task.TABULAR_BINARY_CLASSIFICATION,\n",
138+
" options={'relax_version': False},\n",
139+
" metrics={\n",
140+
" \"test_accuracy\": float(test_accuracy),\n",
141+
" \"test_f1_score\": float(test_f1),\n",
142+
" \"roc_auc\": float(roc_auc),\n",
143+
" \"cv_accuracy_mean\": float(cv_results['XGBoost'].mean()),\n",
144+
" \"cv_accuracy_std\": float(cv_results['XGBoost'].std()),\n",
145+
" \"n_estimators\": 100,\n",
146+
" \"max_depth\": 6,\n",
147+
" \"learning_rate\": 0.1\n",
148+
" },\n",
149+
" comment=\"XGBoost classifier for breast cancer diagnosis. Trained on Wisconsin Diagnostic dataset (569 samples, 30 features). Cross-validated.\"\n",
150+
")\n",
151+
"\n",
152+
"print(\"=\" * 60)\n",
153+
"print(\"MODEL REGISTRY - SUCCESS\")\n",
154+
"print(\"=\" * 60)\n",
155+
"print(f\"Model Name: {MODEL_NAME}\")\n",
156+
"print(f\"Version: {mv.version_name}\")\n",
157+
"print(f\"Test Accuracy: {test_accuracy:.4f}\")\n",
158+
"print(f\"ROC AUC: {roc_auc:.4f}\")"
159+
]
160+
},
161+
{
162+
"cell_type": "markdown",
163+
"id": "3e428a2e",
164+
"metadata": {},
165+
"source": [
166+
"## Step 3: Model Inference\n",
167+
"\n",
168+
"### Running Predictions with the Registered Model\n",
169+
"\n",
170+
"Once deployed to the Model Registry, inference can be performed via:\n",
171+
"\n",
172+
"| Method | Use Case | Scalability |\n",
173+
"|--------|----------|-------------|\n",
174+
"| `mv.run()` (Python) | Notebooks, scripts | Batch processing |\n",
175+
"| `MODEL!PREDICT()` (SQL) | Dashboards, ETL pipelines | Warehouse-scale |\n",
176+
"\n",
177+
"The model executes **within Snowflake** - no data leaves the platform, maintaining security and governance."
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": null,
183+
"id": "904c5d8e",
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"print(f\"Running inference using model: {mv.model_name} (version: {mv.version_name})\")\n",
188+
"predictions = mv.run(X_test, function_name=\"predict\")\n",
189+
"print(f\"Prediction columns: {predictions.columns.tolist()}\")\n",
190+
"pred_col = predictions.columns[-1]\n",
191+
"predictions[[pred_col]].rename(columns={pred_col: \"PREDICTION\"}).head(10)"
192+
]
193+
},
194+
{
195+
"cell_type": "markdown",
196+
"id": "fe14c8d1",
197+
"metadata": {},
198+
"source": [
199+
"## Step 4: Explore Registered Model\n",
200+
"\n",
201+
"The Model Registry stores model artifacts along with metadata. Let's inspect:\n",
202+
"- **Available methods**: predict, predict_proba\n",
203+
"- **Logged metrics**: accuracy, AUC, hyperparameters\n",
204+
"\n",
205+
"> **Tip**: View your model in Snowsight under **AI & ML > Models** for a visual interface."
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": null,
211+
"id": "28216849",
212+
"metadata": {},
213+
"outputs": [],
214+
"source": [
215+
"print(\"Available methods:\")\n",
216+
"for func in mv.show_functions():\n",
217+
" print(f\" - {func['name']}\")\n",
218+
"\n",
219+
"print(f\"\\nModel metrics:\")\n",
220+
"mv.show_metrics()"
221+
]
222+
},
223+
{
224+
"cell_type": "markdown",
225+
"id": "1483b6cb",
226+
"metadata": {},
227+
"source": [
228+
"## Step 5: (Optional) Persist Data to Snowflake\n",
229+
"\n",
230+
"**Data Persistence Options:**\n",
231+
"\n",
232+
"| Method | Use Case | Durability |\n",
233+
"|--------|----------|------------|\n",
234+
"| Snowflake Table | Structured data, SQL queries | Permanent |\n",
235+
"| Snowflake Stage | Files, artifacts | Permanent |\n",
236+
"| Notebook CWD | Temporary files | Session only ⚠️ |\n",
237+
"\n",
238+
"> **Warning**: The notebook working directory (`/home/udf/`) does not persist between sessions. Always save important data to tables or stages."
239+
]
240+
},
241+
{
242+
"cell_type": "code",
243+
"execution_count": null,
244+
"id": "7cf4e1ff",
245+
"metadata": {},
246+
"outputs": [],
247+
"source": [
248+
"# OPTIONAL: Save training data to Snowflake\n",
249+
"# Uncomment and update the database/schema names to match your environment\n",
250+
"\n",
251+
"# train_df = X_train.copy()\n",
252+
"# train_df[\"DIAGNOSIS\"] = y_train.values\n",
253+
"# \n",
254+
"# snowpark_df = session.create_dataframe(train_df)\n",
255+
"# snowpark_df.write.mode(\"overwrite\").save_as_table(\"HEALTHCARE_ML.DIAGNOSTICS.BREAST_CANCER_TRAINING_DATA\")\n",
256+
"# \n",
257+
"# print(\"Training data saved to Snowflake table\")"
258+
]
259+
},
260+
{
261+
"cell_type": "markdown",
262+
"id": "7550b87a",
263+
"metadata": {},
264+
"source": [
265+
"## Summary and Key Takeaways\n",
266+
"\n",
267+
"### What We Accomplished\n",
268+
"\n",
269+
"| Step | Technique | Outcome |\n",
270+
"|------|-----------|---------|\n",
271+
"| Data Exploration | Statistical analysis + visualizations | Understood feature distributions and class balance |\n",
272+
"| Feature Engineering | StandardScaler | Normalized features for fair model comparison |\n",
273+
"| Model Selection | 5-Fold Stratified CV | Compared 3 algorithms, selected XGBoost |\n",
274+
"| Evaluation | Multiple metrics + visualizations | Validated model with ~97% accuracy, 0.99 AUC |\n",
275+
"| Deployment | Snowflake Model Registry | Production-ready model with versioning |\n",
276+
"\n",
277+
"### Performance Summary\n",
278+
"\n",
279+
"| Metric | Value | Interpretation |\n",
280+
"|--------|-------|----------------|\n",
281+
"| Test Accuracy | ~97% | Correct predictions overall |\n",
282+
"| ROC AUC | ~0.99 | Excellent discrimination |\n",
283+
"| Malignant Recall | ~95%+ | Catches most cancers |\n",
284+
"| Benign Precision | ~98%+ | Few false alarms |\n",
285+
"\n",
286+
"### Production Usage\n",
287+
"\n",
288+
"```sql\n",
289+
"-- SQL Inference\n",
290+
"SELECT BREAST_CANCER_CLASSIFIER!PREDICT(*) FROM your_patient_data;\n",
291+
"\n",
292+
"-- Python Inference\n",
293+
"model_version = registry.get_model(\"BREAST_CANCER_CLASSIFIER\").version(\"V1\")\n",
294+
"predictions = model_version.run(new_data, function_name=\"predict\")\n",
295+
"```\n",
296+
"\n",
297+
"### Next Steps\n",
298+
"\n",
299+
"1. **Hyperparameter Tuning**: Use GridSearchCV or Optuna for optimization\n",
300+
"2. **Feature Selection**: Reduce to top 10-15 features for efficiency\n",
301+
"3. **Model Monitoring**: Track prediction drift in production\n",
302+
"4. **A/B Testing**: Compare model versions on live data\n",
303+
"\n",
304+
"> **Resources**: [Snowflake ML Documentation](https://docs.snowflake.com/en/developer-guide/snowflake-ml/overview) | [XGBoost Documentation](https://xgboost.readthedocs.io/)"
305+
]
306+
}
307+
],
308+
"metadata": {
309+
"kernelspec": {
310+
"display_name": "Python 3",
311+
"language": "python",
312+
"name": "python3"
313+
},
314+
"language_info": {
315+
"name": "python",
316+
"version": "3.12.0"
317+
}
318+
},
319+
"nbformat": 4,
320+
"nbformat_minor": 5
321+
}

0 commit comments

Comments
 (0)