Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion samples/ml/ml_jobs/e2e_task_graph/src/modeling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union

import cloudpickle as cp
import data
Expand Down
71 changes: 22 additions & 49 deletions samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import io
import json
import os
import time
Expand All @@ -12,13 +11,13 @@
from snowflake.core.task.dagv1 import DAG, DAGOperation, DAGTask, DAGTaskBranch
from snowflake.ml.data import DatasetInfo
from snowflake.ml.dataset import load_dataset
from snowflake.ml.jobs import MLJob
from snowflake.snowpark import Session
from snowflake.ml.jobs import MLJobDefinition

import cli_utils
import data
import modeling
from constants import (DAG_STAGE, DATA_TABLE_NAME, DB_NAME, SCHEMA_NAME,
from constants import (COMPUTE_POOL,DAG_STAGE, DATA_TABLE_NAME, DB_NAME, JOB_STAGE, SCHEMA_NAME,
WAREHOUSE)

ARTIFACT_DIR = "run_artifacts"
Expand Down Expand Up @@ -180,56 +179,28 @@ def prepare_datasets(session: Session) -> str:
return json.dumps(dataset_info)


def train_model_sql(session: Session) -> str:
    """
    Return the SQL statement that launches model training as an ML Job.

    The actual training logic lives in ``train_model.py``; this function only
    registers that script as an ML Job definition and renders the SQL used by
    the DAG's TRAIN_MODEL task to execute it.

    Args:
        session (Session): Snowflake session object

    Returns:
        str: SQL statement to train a model using ML Jobs
    """
    # NOTE: Remove `target_instances=2` to run training on a single node.
    # See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs
    job_definition = MLJobDefinition.register(
        source="./src/",
        entrypoint="train_model.py",
        compute_pool=COMPUTE_POOL,
        stage_name=JOB_STAGE,
        # NOTE(review): removed the developer-local `imports=[("/Users/ajiang/...", ...)]`
        # override — a machine-specific path must not ship in a sample; the packaged
        # `snowflake-ml-python` is used instead.
        session=session,
        target_instances=2,
    )
    # use_async=False: the rendered SQL blocks until the job completes, so the
    # downstream DAG task only runs after training has finished.
    return job_definition.to_sql(use_async=False)

def check_model_quality(session: Session) -> str:
"""
Expand All @@ -250,7 +221,6 @@ def check_model_quality(session: Session) -> str:

metrics = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))["metrics"]

# If model is good, promote model
threshold = config.metric_threshold
if metrics[config.metric_name] >= threshold:
return "promote_model"
Expand Down Expand Up @@ -341,7 +311,9 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
schedule=schedule,
use_func_return_value=True,
stage_location=DAG_STAGE,
packages=["snowflake-snowpark-python", "snowflake-ml-python<1.9.0", "xgboost"], # NOTE: Temporarily pinning to <1.9.0 due to compatibility issues
# Pin `xgboost` to reduce training/logging runtime drift (ML Jobs vs Task Graph runtime).
packages=["snowflake-snowpark-python", "xgboost==2.1.3", "snowflake-ml-python"],
imports=[("/Users/ajiang/PycharmProjects/snowml/snowflake/ml", "snowflake.ml")],
config={
"dataset_name": "mortgage_dataset",
"model_name": "mortgage_model",
Expand Down Expand Up @@ -369,10 +341,11 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
);
""",
)
cleanup_task = DAGTask("cleanup_task", definition=cleanup, is_finalizer=True)
train_model_task = DAGTask("TRAIN_MODEL", definition=train_model_sql(session))
_cleanup_task = DAGTask("cleanup_task", definition=cleanup, is_finalizer=True)

# Build the DAG
prepare_data >> train_model >> evaluate_model >> [promote_model_task, alert_task]
prepare_data >> train_model_task >> evaluate_model >> [promote_model_task, alert_task]

return dag

Expand Down
95 changes: 95 additions & 0 deletions samples/ml/ml_jobs/e2e_task_graph/src/train_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from snowflake.ml.data import DataConnector, DatasetInfo, DataSource
from snowflake.core.task.context import TaskContext
from snowflake.snowpark import Session
from xgboost import XGBClassifier
import os
import json
import cloudpickle as cp
import io
from pipeline_dag import RunConfig
from modeling import evaluate_model

session = Session.builder.getOrCreate()

def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
    """
    Train an XGBoost classifier on the provided training data.

    Extracts features and labels from the input data, configures the model with
    predefined parameters, and trains either a distributed estimator (multi-node
    ML Job) or a standard XGBClassifier (single node). This function is executed
    remotely on Snowpark Container Services. Evaluation is done by the caller.

    Args:
        session (Session): Snowflake session object
        input_data (DataSource): Data source containing training data; must be a
            DatasetInfo so the label/exclude column metadata is available

    Returns:
        XGBClassifier: Trained XGBoost classifier model

    Raises:
        TypeError: If `input_data` is not a DatasetInfo.
    """
    # Validate BEFORE first use (was previously asserted after the data was
    # already loaded); a raise survives `python -O`, unlike `assert`.
    if not isinstance(input_data, DatasetInfo):
        raise TypeError("Input data must be a DatasetInfo")

    input_data_df = DataConnector.from_sources(session, [input_data]).to_pandas()

    exclude_cols = input_data.exclude_cols
    # Convention here: the label column is the first excluded column.
    label_col = exclude_cols[0]

    X_train = input_data_df.drop(exclude_cols, axis=1)
    y_train = input_data_df[label_col].squeeze()

    model_params = dict(
        max_depth=50,
        n_estimators=3,
        learning_rate=0.75,
        objective="binary:logistic",
        booster="gbtree",
    )

    # Number of nodes allocated to this ML Job (defaults to 1 when unset).
    # assumes SNOWFLAKE_JOBS_COUNT holds the instance count — TODO confirm.
    num_nodes = int(os.environ.get("SNOWFLAKE_JOBS_COUNT", 1))
    if num_nodes > 1:
        # Distributed training - use ML Runtime distributor APIs
        from snowflake.ml.modeling.distributors.xgboost.xgboost_estimator import (
            XGBEstimator,
            XGBScalingConfig,
        )

        # Scale workers with the actual node count instead of hard-coding 2.
        estimator = XGBEstimator(
            params=model_params,
            scaling_config=XGBScalingConfig(num_workers=num_nodes),
        )
    else:
        # Single node training - can use standard XGBClassifier
        estimator = XGBClassifier(**model_params)

    estimator.fit(X_train, y_train)

    # The distributed estimator wraps a sklearn-compatible model; unwrap it so
    # callers always receive a plain XGBClassifier-like object.
    return getattr(estimator, "_sklearn_estimator", estimator)


if __name__ == "__main__":
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)

# Load the datasets
serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
dataset_info = {
key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items()
}
artifact_dir = config.artifact_dir
model_obj = train_model(session, dataset_info["train"])
train_metrics = evaluate_model(
session, model_obj, dataset_info["train"], prefix="train"
)
test_metrics = evaluate_model(
session, model_obj, dataset_info["test"], prefix="test"
)
metrics = {**train_metrics, **test_metrics}

model_pkl = cp.dumps(model_obj)
model_path = os.path.join(config.artifact_dir, "model.pkl")
put_result = session.file.put_stream(
io.BytesIO(model_pkl), model_path, overwrite=True
)
result_dict = {
"model_path": os.path.join(config.artifact_dir, put_result.target),
"metrics": metrics,
}
ctx.set_return_value(json.dumps(result_dict))