5 changes: 2 additions & 3 deletions samples/ml/ml_jobs/e2e_task_graph/README.md
@@ -120,7 +120,6 @@ Run the ML pipeline locally without task graph orchestration:

```bash
python src/pipeline_local.py
python src/pipeline_local.py --no-register # Skip model registration for faster experimentation
```

You can monitor the corresponding ML Job for model training via the [Job UI in Snowsight](../README.md#job-ui-in-snowsight).
@@ -193,7 +192,7 @@ The `train_model` function uses the `@remote` decorator to run multi-node training:

```python
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
def train_model(input_data: DataSource) -> Optional[str]:
# Training logic runs on distributed compute
```

Expand All @@ -213,6 +212,6 @@ def check_model_quality(session: Session) -> str:
Successful models are automatically registered and promoted to production:

```python
mv = register_model(session, model, model_name, version, train_ds, metrics)
# get model version from train model
promote_model(session, mv) # Sets as default version
```
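As a rough sketch of that hand-off under the revised flow (assuming the `get_model` helper added to `ops.py` below and a JSON return value from the training task; names are illustrative, not part of the diff):

```python
import json

import modeling  # provides promote_model(session, model_version)
import ops       # provides get_model(session, model_name, version_name)


def promote_from_train_result(session, train_result_json: str) -> None:
    # The training task returns {"model_name": ..., "version_name": ...} as JSON.
    model_info = json.loads(train_result_json)
    mv = ops.get_model(session, model_info["model_name"], model_info["version_name"])
    modeling.promote_model(session, mv)  # sets this version as the default
```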
20 changes: 10 additions & 10 deletions samples/ml/ml_jobs/e2e_task_graph/src/modeling.py
@@ -1,13 +1,12 @@
import os
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
from typing import Optional

import cloudpickle as cp
import data
import ops
from constants import (
COMPUTE_POOL,
DAG_STAGE,
DB_NAME,
JOB_STAGE,
@@ -18,7 +17,6 @@
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from snowflake.ml.data import DataConnector, DatasetInfo, DataSource
from snowflake.ml.dataset import Dataset, load_dataset
from snowflake.ml.jobs import remote
from snowflake.ml.model import ModelVersion
from snowflake.snowpark import Session
from snowflake.snowpark.exceptions import SnowparkSQLException
@@ -144,10 +142,7 @@ def prepare_datasets(
return (ds, train_ds, test_ds)


# NOTE: Remove `target_instances=2` to run training on a single node
# See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
Comment on lines -147 to -149
Collaborator:

One of the main points of this sample is to demonstrate how easy it is to convert a local pipeline by pushing certain steps down into ML Jobs. Needing to write a separate script file that we submit_file() just for this conversion severely weakens that story. Why can't we keep using a @remote()-decorated function? @remote(...) should convert the function into an MLJobDefinition that we can use directly in pipeline_dag without needing an explicit MLJobDefinition.register() call.

Collaborator (Author):

Currently @remote does not create a job definition; it creates a job directly. We have only merged the PR for phase one, and phase two is still in review.

Collaborator:

Let's hold off on merging this until @remote is ready, then.
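For reference, a minimal sketch of the pattern under discussion, assuming only what the diff shows: an `@remote`-decorated function submits an ML Job when called, and the returned `MLJob` handle exposes `result()`. The `MLJobDefinition` flow mentioned above is hypothetical until phase two lands, and the dataset argument below is illustrative.

```python
from snowflake.ml.jobs import remote

from constants import COMPUTE_POOL, JOB_STAGE


# Calling the decorated function submits the body as an ML Job on the compute pool.
@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2)
def train_model(input_data):
    # Training logic runs inside the job payload on distributed compute.
    ...


# The call returns an MLJob handle; result() blocks until the job finishes
# and returns the function's return value.
job = train_model(train_dataset_info)  # train_dataset_info is illustrative
trained_model = job.result()
```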

def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
def train_model(session: Session, input_data: Optional[DataSource] = None) -> XGBClassifier:
"""
Train a model on the training dataset.

@@ -163,6 +158,7 @@ def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
XGBClassifier: Trained XGBoost classifier model
"""
input_data_df = DataConnector.from_sources(session, [input_data]).to_pandas()


assert isinstance(input_data, DatasetInfo), "Input data must be a DatasetInfo"
exclude_cols = input_data.exclude_cols
@@ -195,8 +191,7 @@ def train_model(session: Session, input_data: DataSource) -> XGBClassifier:
estimator = XGBClassifier(**model_params)

estimator.fit(X_train, y_train)

# Convert distributed estimator to standard XGBClassifier if needed

return getattr(estimator, '_sklearn_estimator', estimator)


@@ -232,7 +227,12 @@ def evaluate_model(

X_test = input_data_df.drop(exclude_cols, axis=1)
expected = input_data_df[label_col].squeeze()
actual = model.predict(X_test)
# inside evaluate_model
if isinstance(model, ModelVersion):
preds_df = model.run(X_test, function_name="predict")
actual = preds_df.iloc[:, -1]
else:
actual = model.predict(X_test)

metric_types = [
f1_score,
4 changes: 4 additions & 0 deletions samples/ml/ml_jobs/e2e_task_graph/src/ops.py
@@ -166,3 +166,7 @@ def promote_model(
# Set model as default
base_model = registry.get_model(model.model_name)
base_model.default = model

def get_model(session: Session, model_name: str, version_name: str) -> ModelVersion:
registry = get_model_registry(session)
return registry.get_model(model_name).version(version_name)
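
A short usage sketch for the new helper (the model and version names below are illustrative, not taken from the diff):

```python
from snowflake.snowpark import Session

import ops

session = Session.builder.getOrCreate()

# Fetch a specific registered version; downstream DAG tasks use this lookup
# instead of unpickling a model artifact from a stage.
mv = ops.get_model(session, "mortgage_model", "v20240101_000000")
print(mv.fully_qualified_model_name, mv.version_name)
```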
182 changes: 63 additions & 119 deletions samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py
@@ -1,29 +1,36 @@
import io
import json
import os
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from dataclasses import asdict
from datetime import timedelta
from typing import Any, Optional
import uuid

import cloudpickle as cp
from snowflake.core import CreateMode, Root
from snowflake.core.task.context import TaskContext
from snowflake.core.task.dagv1 import DAG, DAGOperation, DAGTask, DAGTaskBranch
from snowflake.ml.data import DatasetInfo
from snowflake.ml.dataset import load_dataset
from snowflake.ml.jobs import MLJob
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark import Session
from snowflake.ml.jobs import remote
import modeling
import data
import ops
import run_config

import cli_utils
import data
import modeling
from constants import (DAG_STAGE, DATA_TABLE_NAME, DB_NAME, SCHEMA_NAME,
from constants import (COMPUTE_POOL, DAG_STAGE, DATA_TABLE_NAME, DB_NAME, JOB_STAGE, SCHEMA_NAME,
WAREHOUSE)

ARTIFACT_DIR = "run_artifacts"
# Ensure local modules are bundled for remote job execution.
cp.register_pickle_by_value(modeling)
cp.register_pickle_by_value(data)
cp.register_pickle_by_value(ops)
cp.register_pickle_by_value(run_config)


session = Session.builder.getOrCreate()
def _ensure_environment(session: Session):
"""
Ensure the environment is properly set up for DAG execution.
@@ -41,7 +48,6 @@ def _ensure_environment(session: Session):
_ = data.get_raw_data(session, DATA_TABLE_NAME, create_if_not_exists=True)

# Register local modules for inclusion in ML Job payloads
cp.register_pickle_by_value(modeling)


def _wait_for_run_to_complete(session: Session, dag: DAG) -> str:
@@ -108,49 +114,6 @@ def _wait_for_run_to_complete(session: Session, dag: DAG) -> str:

return dag_result


@dataclass(frozen=True)
class RunConfig:
run_id: str
dataset_name: str
model_name: str
metric_name: str
metric_threshold: float

@property
def artifact_dir(self) -> str:
return os.path.join(DAG_STAGE, ARTIFACT_DIR, self.run_id)

@classmethod
def from_task_context(cls, ctx: TaskContext, **kwargs: Any) -> "RunConfig":
run_schedule = ctx.get_current_task_graph_original_schedule()
run_id = "v" + (
run_schedule.strftime("%Y%m%d_%H%M%S")
if isinstance(run_schedule, datetime)
else str(run_schedule)
)
run_config = dict(run_id=run_id)

graph_config = ctx.get_task_graph_config()
merged = run_config | graph_config | kwargs

# Get expected fields from RunConfig
expected_fields = set(cls.__annotations__)

# Find unexpected keys
unexpected_keys = [key for key in merged.keys() if key not in expected_fields]
for key in unexpected_keys:
print(f"Warning: Unexpected config key '{key}' will be ignored")

filtered = {k: v for k, v in merged.items() if k in expected_fields}
return cls(**filtered)

@classmethod
def from_session(cls, session: Session) -> "RunConfig":
ctx = TaskContext(session)
return cls.from_task_context(ctx)


def prepare_datasets(session: Session) -> str:
"""
DAG task to prepare datasets for model training.
@@ -166,7 +129,7 @@ def prepare_datasets(session: Session) -> str:
str: JSON string containing serialized dataset information for downstream tasks
"""
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)
config = run_config.RunConfig.from_task_context(ctx)

ds, train_ds, test_ds = modeling.prepare_datasets(
session, DATA_TABLE_NAME, config.dataset_name
@@ -179,36 +142,48 @@ def prepare_datasets(session: Session) -> str:
}
return json.dumps(dataset_info)

@remote(COMPUTE_POOL, stage_name=JOB_STAGE)
def train_model(dataset_info: Optional[str] = None) -> Optional[str]:
session = Session.builder.getOrCreate()
ctx = None
config = None

def train_model(session: Session) -> str:
"""
DAG task to train a machine learning model.

This function is executed as part of the DAG workflow to train a model using the prepared datasets.
It retrieves dataset information from the previous task, trains the model, evaluates it on both
training and test sets, and saves the model to a stage for later use.

Args:
session (Session): Snowflake session object
if dataset_info:
dataset_info_dicts = json.loads(dataset_info)
try:
ctx = TaskContext(session)
print("ctx", ctx)
config = run_config.RunConfig.from_task_context(ctx)
dataset_info_dicts = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
except SnowparkSQLException:
print("there is no predecessor return value, fallback to local mode")

datasets = {
key: DatasetInfo(**info_dict) for key, info_dict in dataset_info_dicts.items()
}
train_ds=load_dataset(
session,
datasets["full"].fully_qualified_name,
datasets["full"].version,
)
model_obj = modeling.train_model(session, datasets["train"])
train_metrics = modeling.evaluate_model(
session, model_obj, train_ds.read.data_sources[0], prefix="train"
)
version = f"v{uuid.uuid4().hex}"
mv = modeling.register_model(session, model_obj, config.model_name if config and config.model_name else "mortgage_model", version, train_ds, metrics={}) if config else modeling.register_model(session, model_obj, "mortgage_model", version, train_ds, metrics=train_metrics)
if ctx and config:
ctx.set_return_value(json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name}))
return json.dumps({"model_name": mv.fully_qualified_model_name, "version_name": mv.version_name})

Returns:
str: JSON string containing model path and evaluation metrics
"""
def evaluate_model(session: Session) -> Optional[str]:
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)

# Load the datasets
serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
dataset_info = {
key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items()
}

# Train the model
model = modeling.train_model(session, dataset_info["train"])
if isinstance(model, MLJob):
model = model.result()

# Evaluate the model
model_info = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))
model = ops.get_model(session, model_info["model_name"], model_info["version_name"])
train_metrics = modeling.evaluate_model(
session, model, dataset_info["train"], prefix="train"
)
@@ -217,18 +192,7 @@ def train_model(session: Session) -> str:
)
metrics = {**train_metrics, **test_metrics}

# Save model to stage and return the metrics as a JSON string
model_pkl = cp.dumps(model)
model_path = os.path.join(config.artifact_dir, "model.pkl")
put_result = session.file.put_stream(
io.BytesIO(model_pkl), model_path, overwrite=True
)

result_dict = {
"model_path": os.path.join(config.artifact_dir, put_result.target),
"metrics": metrics,
}
return json.dumps(result_dict)
return json.dumps(metrics)


def check_model_quality(session: Session) -> str:
@@ -246,11 +210,10 @@ def check_model_quality(session: Session) -> str:
str: "promote_model" if model meets threshold, "send_alert" otherwise
"""
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)
config = run_config.RunConfig.from_task_context(ctx)

metrics = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))["metrics"]
metrics = json.loads(ctx.get_predecessor_return_value("EVALUATE_MODEL"))

# If model is good, promote model
threshold = config.metric_threshold
if metrics[config.metric_name] >= threshold:
return "promote_model"
@@ -273,28 +236,8 @@ def promote_model(session: Session) -> str:
str: Tuple of (fully_qualified_model_name, version_name) as string
"""
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)

train_result = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))
model_path = train_result["model_path"]
with session.file.get_stream(model_path, decompress=True) as stream:
model = cp.loads(stream.read())

serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA"))
source_data = {key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items()}
mv = modeling.register_model(
session,
model,
model_name=config.model_name,
version_name=config.run_id,
train_ds=load_dataset(
session,
source_data["full"].fully_qualified_name,
source_data["full"].version,
),
metrics=train_result["metrics"],
)

model_info = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))
mv = ops.get_model(session, model_info["model_name"], model_info["version_name"])
modeling.promote_model(session, mv)

return (mv.fully_qualified_model_name, mv.version_name)
@@ -312,9 +255,8 @@ def cleanup(session: Session) -> None:
session (Session): Snowflake session object
"""
ctx = TaskContext(session)
config = RunConfig.from_task_context(ctx)
config = run_config.RunConfig.from_task_context(ctx)

session.sql(f"REMOVE {config.artifact_dir}").collect()
modeling.clean_up(session, config.dataset_name, config.model_name)


@@ -341,7 +283,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
schedule=schedule,
use_func_return_value=True,
stage_location=DAG_STAGE,
packages=["snowflake-snowpark-python", "snowflake-ml-python<1.9.0", "xgboost"], # NOTE: Temporarily pinning to <1.9.0 due to compatibility issues
packages=["snowflake-snowpark-python", "snowflake-ml-python", "xgboost"],
config={
"dataset_name": "mortgage_dataset",
"model_name": "mortgage_model",
Expand All @@ -352,7 +294,9 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
) as dag:
# Need to wrap first function in a DAGTask to make >> operator work properly
prepare_data = DAGTask("prepare_data", definition=prepare_datasets)
evaluate_model = DAGTaskBranch(
train_model_task = DAGTask("TRAIN_MODEL", definition=train_model)
evaluate_model_task = DAGTask("EVALUATE_MODEL", definition=evaluate_model)
check_model_quality_task = DAGTaskBranch(
"check_model_quality", definition=check_model_quality
)
promote_model_task = DAGTask("promote_model", definition=promote_model)
@@ -372,7 +316,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) -
cleanup_task = DAGTask("cleanup_task", definition=cleanup, is_finalizer=True)

# Build the DAG
prepare_data >> train_model >> evaluate_model >> [promote_model_task, alert_task]
prepare_data >> train_model_task >> evaluate_model_task >> check_model_quality_task >> [promote_model_task, alert_task]

return dag
