14 changes: 10 additions & 4 deletions samples/ml/ml_jobs/e2e_task_graph/README.md
@@ -189,12 +189,18 @@ This visual interface makes it easy to:

### Model Training on SPCS using ML Jobs

The `train_model` function uses the `@remote` decorator to run multi-node training on Snowpark Container Services:
The Task Graph runs model training on SPCS via a Snowflake ML Job entrypoint (`src/train_model.py`), which calls `modeling.train_model(...)` and returns metrics and a serialized model path back to the DAG.

```python
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
# Training logic runs on distributed compute
train_job_definition = MLJobDefinition.register(
source="./src",
entrypoint="train_model.py",
compute_pool=COMPUTE_POOL,
stage_name=JOB_STAGE,
session=session,
target_instances=2,
)
train_model_task = DAGTask("TRAIN_MODEL", definition=train_job_definition)
```
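
On the job side, the entrypoint reports its results back to the DAG through the task context. The following is a minimal sketch of that pattern, condensed from `src/train_model.py` in this PR (error handling and the model-upload step are omitted):

```python
import json

from snowflake.core.task.context import TaskContext
from snowflake.ml.data import DatasetInfo
from snowflake.snowpark import Session

import modeling  # shipped with the job payload via source="./src"

session = Session.builder.getOrCreate()
ctx = TaskContext(session)

# Datasets prepared by the upstream PREPARE_DATA task
serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
dataset_info = {key: DatasetInfo(**obj) for key, obj in serialized.items()}

# Train and evaluate using the shared modeling module
model = modeling.train_model(session, dataset_info["train"])
metrics = modeling.evaluate_model(session, model, dataset_info["test"], prefix="test")

# Hand the metrics (and, in the full script, a serialized model path) back to downstream tasks
ctx.set_return_value(json.dumps({"metrics": metrics}))
```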

### Conditional Model Promotion
Expand Down
16 changes: 1 addition & 15 deletions samples/ml/ml_jobs/e2e_task_graph/src/modeling.py
@@ -1,7 +1,6 @@
import os
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union

import cloudpickle as cp
import data
@@ -144,23 +143,10 @@ def prepare_datasets(
return (ds, train_ds, test_ds)


# NOTE: Remove `target_instances=2` to run training on a single node
# See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
Comment on lines -147 to -149

Collaborator: One of the main points of this sample is to demonstrate how easy it is to convert a local pipeline by pushing certain steps down into ML Jobs. Needing to write a separate script file that we pass to submit_file() just for this conversion severely weakens this story. Why can't we just keep using a @remote() decorated function? @remote(...) should convert the function into an MLJobDefinition which we can use directly in pipeline_dag without needing an explicit MLJobDefinition.register() call.

Collaborator (Author): Currently, @remote does not create a job definition; it creates a job directly. We have only merged the PR for phase one, and phase two is still in review.

Collaborator: Let's hold off on merging this until @remote is ready, then.

Collaborator: Since the @remote change is now available, can we now call this as an ML Job directly from pipeline_dag?

Collaborator (Author): I'm a little confused here. Do you mean we should create the job inside the task directly?
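
For context, the conversion the reviewer is asking for would look roughly like the sketch below. This is hypothetical: it assumes the phase-2 behavior described in the thread, where @remote(...) itself yields a job definition that a DAGTask can consume, which is not how the API behaves in this PR.

```python
from snowflake.core.task.dagv1 import DAGTask
from snowflake.ml.jobs import remote

from constants import COMPUTE_POOL, JOB_STAGE

# Hypothetical phase-2 usage: the decorated function doubles as the task definition,
# so no separate entrypoint script or MLJobDefinition.register() call is needed.
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(session, input_data):
    ...  # training logic pushed down to SPCS

train_model_task = DAGTask("TRAIN_MODEL", definition=train_model)
```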

def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
"""
Train a model on the training dataset.

This function trains an XGBoost classifier on the provided training data. It extracts
features and labels from the input data, configures the model with predefined parameters,
and trains the model. This function is executed remotely on Snowpark Container Services.

Args:
session (Session): Snowflake session object
input_data (DataSource): Data source containing training data with features and labels

Returns:
XGBClassifier: Trained XGBoost classifier model
"""
input_data_df = DataConnector.from_sources(session, [input_data]).to_pandas()

@@ -197,7 +183,7 @@ def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
estimator.fit(X_train, y_train)

# Convert distributed estimator to standard XGBClassifier if needed
return getattr(estimator, '_sklearn_estimator', estimator)
return getattr(estimator, "_sklearn_estimator", estimator) or estimator


def evaluate_model(
71 changes: 13 additions & 58 deletions samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py
@@ -1,4 +1,3 @@
import io
import json
import os
import time
@@ -12,13 +11,13 @@
from snowflake.core.task.dagv1 import DAG, DAGOperation, DAGTask, DAGTaskBranch
from snowflake.ml.data import DatasetInfo
from snowflake.ml.dataset import load_dataset
from snowflake.ml.jobs import MLJob
from snowflake.snowpark import Session
from snowflake.ml.jobs import MLJobDefinition

import cli_utils
import data
import modeling
from constants import (DAG_STAGE, DATA_TABLE_NAME, DB_NAME, SCHEMA_NAME,
from constants import (COMPUTE_POOL, DAG_STAGE, DATA_TABLE_NAME, DB_NAME, JOB_STAGE, SCHEMA_NAME,
WAREHOUSE)

ARTIFACT_DIR = "run_artifacts"
@@ -179,58 +178,6 @@ def prepare_datasets(session: Session) -> str:
}
return json.dumps(dataset_info)


def train_model(session: Session) -> str:
"""
DAG task to train a machine learning model.

This function is executed as part of the DAG workflow to train a model using the prepared datasets.
It retrieves dataset information from the previous task, trains the model, evaluates it on both
training and test sets, and saves the model to a stage for later use.

Args:
session (Session): Snowflake session object

Returns:
str: JSON string containing model path and evaluation metrics
"""
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)

# Load the datasets
serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
dataset_info = {
key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items()
}

# Train the model
model = modeling.train_model(session, dataset_info["train"])
if isinstance(model, MLJob):
model = model.result()

# Evaluate the model
train_metrics = modeling.evaluate_model(
session, model, dataset_info["train"], prefix="train"
)
test_metrics = modeling.evaluate_model(
session, model, dataset_info["test"], prefix="test"
)
metrics = {**train_metrics, **test_metrics}

# Save model to stage and return the metrics as a JSON string
model_pkl = cp.dumps(model)
model_path = os.path.join(config.artifact_dir, "model.pkl")
put_result = session.file.put_stream(
io.BytesIO(model_pkl), model_path, overwrite=True
)

result_dict = {
"model_path": os.path.join(config.artifact_dir, put_result.target),
"metrics": metrics,
}
return json.dumps(result_dict)


def check_model_quality(session: Session) -> str:
"""
DAG task to check model quality and determine next action.
@@ -250,7 +197,6 @@ def check_model_quality(session: Session) -> str:

metrics = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))["metrics"]

# If model is good, promote model
threshold = config.metric_threshold
if metrics[config.metric_name] >= threshold:
return "promote_model"
@@ -341,7 +287,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
schedule=schedule,
use_func_return_value=True,
stage_location=DAG_STAGE,
packages=["snowflake-snowpark-python", "snowflake-ml-python<1.9.0", "xgboost"], # NOTE: Temporarily pinning to <1.9.0 due to compatibility issues
packages=["snowflake-snowpark-python", "snowflake-ml-python", "xgboost"],
config={
"dataset_name": "mortgage_dataset",
"model_name": "mortgage_model",
@@ -352,6 +298,15 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
) as dag:
# Need to wrap first function in a DAGTask to make >> operator work properly
prepare_data = DAGTask("prepare_data", definition=prepare_datasets)
train_job_definition = MLJobDefinition.register(
source="./src",
entrypoint="train_model.py",
compute_pool=COMPUTE_POOL,
stage_name=JOB_STAGE,
session=session,
target_instances=2, # NOTE: remove to run on a single node
)
train_model_task = DAGTask("TRAIN_MODEL", definition=train_job_definition)
evaluate_model = DAGTaskBranch(
"check_model_quality", definition=check_model_quality
)
@@ -372,7 +327,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
cleanup_task = DAGTask("cleanup_task", definition=cleanup, is_finalizer=True)

# Build the DAG
prepare_data >> train_model >> evaluate_model >> [promote_model_task, alert_task]
prepare_data >> train_model_task >> evaluate_model >> [promote_model_task, alert_task]

return dag

33 changes: 22 additions & 11 deletions samples/ml/ml_jobs/e2e_task_graph/src/pipeline_local.py
@@ -1,9 +1,13 @@
from dataclasses import asdict
import json
import logging
from datetime import datetime
from snowflake.ml import jobs
from snowflake.snowpark import Session

import cloudpickle as cp
import modeling
from constants import DATA_TABLE_NAME
from constants import COMPUTE_POOL, DATA_TABLE_NAME, JOB_STAGE

logging.getLogger().setLevel(logging.ERROR)

@@ -39,23 +43,30 @@ def run_pipeline(
create_assets=False,
force_refresh=force_refresh,
)

dataset_info = {
"full": asdict(ds.read.data_sources[0]),
"train": asdict(train_ds.read.data_sources[0]),
"test": asdict(test_ds.read.data_sources[0]),
}
print("Training model...")
model_obj = modeling.train_model(session, train_ds.read.data_sources[0]).result()

print("Evaluating model...")
train_metrics = modeling.evaluate_model(
session, model_obj, train_ds.read.data_sources[0], prefix="train"
)
test_metrics = modeling.evaluate_model(
session, model_obj, test_ds.read.data_sources[0], prefix="test"

job = jobs.submit_directory(
dir_path="./src",
entrypoint="train_model.py",
compute_pool=COMPUTE_POOL,
stage_name=JOB_STAGE,
session=session,
args=["--dataset-info", json.dumps(dataset_info)],
target_instances=2,
)
metrics = {**train_metrics, **test_metrics}

job.wait()
model_obj = job.result()["model_obj"]
metrics = job.result()["metrics"]
key_metric = "test_accuracy"
threshold = 0.7
current_score = metrics[key_metric]
print(f"Current score: {current_score}. Threshold for promotion: {threshold}.")

if no_register:
print("Model registration disabled via --no-register flag.")
69 changes: 69 additions & 0 deletions samples/ml/ml_jobs/e2e_task_graph/src/train_model.py
@@ -0,0 +1,69 @@
from snowflake.ml.data import DatasetInfo
from snowflake.core.task.context import TaskContext
from snowflake.snowpark import Session
import os
import json
import cloudpickle as cp
import io
import argparse

from pipeline_dag import RunConfig
from modeling import evaluate_model, train_model

session = Session.builder.getOrCreate()


if __name__ == "__main__":
    index = int(os.environ.get("SNOWFLAKE_JOB_INDEX", 0))

    # Only head node saves and returns results
    if index != 0:
        print(f"Worker node (index {index}) - exiting")
        exit(0)
    artifact_dir = None
    try:
        ctx = TaskContext(session)
        config = RunConfig.from_task_context(ctx)
        artifact_dir = config.artifact_dir

        # Load the datasets
        serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))

    except Exception as e:
        print(f"Error loading dataset info: {e}")
        parser = argparse.ArgumentParser()
        parser.add_argument("--dataset-info", type=str, required=True)
        args = parser.parse_args()
        serialized = json.loads(args.dataset_info)

    dataset_info = {
        key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items()
    }
    model_obj = train_model(session, dataset_info["train"])

    if not hasattr(model_obj, "feature_weights"):
        model_obj.feature_weights = None
    train_metrics = evaluate_model(
        session, model_obj, dataset_info["train"], prefix="train"
    )
    test_metrics = evaluate_model(
        session, model_obj, dataset_info["test"], prefix="test"
    )
    metrics = {**train_metrics, **test_metrics}
    if artifact_dir:
        model_pkl = cp.dumps(model_obj)
        model_path = os.path.join(config.artifact_dir, "model.pkl")
        put_result = session.file.put_stream(
            io.BytesIO(model_pkl), model_path, overwrite=True
        )
        result_dict = {
            "model_path": os.path.join(config.artifact_dir, put_result.target),
            "metrics": metrics,
        }
        ctx.set_return_value(json.dumps(result_dict))
    else:
        result_dict = {
            "model_obj": model_obj,
            "metrics": metrics,
        }
        __return__ = result_dict