From c7fd3e0e13ad0b43fc94300a22179f40a01f4f3b Mon Sep 17 00:00:00 2001 From: smouaa Date: Mon, 27 Oct 2025 08:44:48 +0000 Subject: [PATCH 1/3] Added support for Sagemaker ENV variables + custom formatter scripts --- .../python/setup/djl_python/input_parser.py | 22 + .../setup/djl_python/sklearn_handler.py | 132 +++-- .../djl_python/tests/test_encode_decode.py | 3 +- engines/python/setup/djl_python/utils.py | 69 ++- .../setup/djl_python/xgboost_handler.py | 137 +++-- serving/build.gradle.kts | 3 + .../ai/djl/serving/util/ConfigManager.java | 12 +- tests/integration/download_models.sh | 16 + tests/integration/test_custom_formatters.py | 543 ++++++++++++++++++ .../test_sagemaker_compatibility.py | 279 +++++++++ tests/integration/test_xgb_skl.py | 60 +- .../java/ai/djl/serving/wlm/ModelInfo.java | 6 + .../wlm/util/SageMakerCompatibility.java | 123 ++++ 13 files changed, 1274 insertions(+), 131 deletions(-) create mode 100644 tests/integration/test_custom_formatters.py create mode 100644 tests/integration/test_sagemaker_compatibility.py create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/SageMakerCompatibility.java diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index 436ad8fdb..bc54066ec 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -36,6 +36,28 @@ def input_formatter(function): return function +def predict_formatter(function): + """ + Decorator for predict_formatter. User just need to annotate @predict_formatter for their custom defined function. + :param function: Decorator takes in the function and adds an attribute. + :return: + """ + # adding an attribute to the function, which is used to find the decorated function. + function.is_predict_formatter = True + return function + + +def model_loading_formatter(function): + """ + Decorator for model_loading_formatter. 
User just need to annotate @model_loading_formatter for their custom defined function. + :param function: Decorator takes in the function and adds an attribute. + :return: + """ + # adding an attribute to the function, which is used to find the decorated function. + function.is_model_loading_formatter = True + return function + + @dataclass class ParsedInput: errors: dict = field(default_factory=lambda: {}) diff --git a/engines/python/setup/djl_python/sklearn_handler.py b/engines/python/setup/djl_python/sklearn_handler.py index fa9d4b03b..1fd66027f 100644 --- a/engines/python/setup/djl_python/sklearn_handler.py +++ b/engines/python/setup/djl_python/sklearn_handler.py @@ -18,7 +18,7 @@ from typing import Optional from djl_python import Input, Output from djl_python.encode_decode import decode -from djl_python.utils import find_model_file +from djl_python.utils import find_model_file, get_sagemaker_function from djl_python.service_loader import get_annotated_function from djl_python.import_utils import joblib, cloudpickle, skops_io as sio @@ -31,6 +31,49 @@ def __init__(self): self.custom_input_formatter = None self.custom_output_formatter = None self.custom_predict_formatter = None + self.custom_model_loading_formatter = None + self.init_properties = None + self.is_sagemaker_script = False + + def _load_custom_formatters(self, model_dir: str): + """Load custom formatters, checking DJL decorators first, then SageMaker functions.""" + # Check for DJL decorator-based custom formatters first + self.custom_model_loading_formatter = get_annotated_function( + model_dir, "is_model_loading_formatter") + self.custom_input_formatter = get_annotated_function( + model_dir, "is_input_formatter") + self.custom_output_formatter = get_annotated_function( + model_dir, "is_output_formatter") + self.custom_predict_formatter = get_annotated_function( + model_dir, "is_predict_formatter") + + # If no decorator-based formatters found, check for SageMaker-style formatters + if not any([ + 
self.custom_input_formatter, self.custom_output_formatter, + self.custom_predict_formatter, + self.custom_model_loading_formatter + ]): + + sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') + sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') + sagemaker_predict_fn = get_sagemaker_function( + model_dir, 'predict_fn') + sagemaker_output_fn = get_sagemaker_function( + model_dir, 'output_fn') + + if any([ + sagemaker_model_fn, sagemaker_input_fn, + sagemaker_predict_fn, sagemaker_output_fn + ]): + self.is_sagemaker_script = True + if sagemaker_model_fn: + self.custom_model_loading_formatter = sagemaker_model_fn + if sagemaker_input_fn: + self.custom_input_formatter = sagemaker_input_fn + if sagemaker_predict_fn: + self.custom_predict_formatter = sagemaker_predict_fn + if sagemaker_output_fn: + self.custom_output_formatter = sagemaker_output_fn def _get_trusted_types(self, properties: dict): trusted_types_str = properties.get("skops_trusted_types", "") @@ -46,6 +89,8 @@ def _get_trusted_types(self, properties: dict): return trusted_types def initialize(self, properties: dict): + # Store initialization properties for use during inference + self.init_properties = properties.copy() model_dir = properties.get("model_dir") model_format = properties.get("model_format", "skops") @@ -62,43 +107,51 @@ def initialize(self, properties: dict): f"Unsupported model format: {model_format}. 
Supported formats: skops, joblib, pickle, cloudpickle" ) - model_file = find_model_file(model_dir, extensions) - if not model_file: - raise FileNotFoundError( - f"No model file found with format '{model_format}' in {model_dir}" - ) + # Load custom formatters + self._load_custom_formatters(model_dir) - if model_format == "skops": - trusted_types = self._get_trusted_types(properties) - self.model = sio.load(model_file, trusted=trusted_types) + # Load model + if self.custom_model_loading_formatter: + self.model = self.custom_model_loading_formatter(model_dir) else: - if properties.get("trust_insecure_model_files", - "false").lower() != "true": - raise ValueError( - f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)" + model_file = find_model_file(model_dir, extensions) + if not model_file: + raise FileNotFoundError( + f"No model file found with format '{model_format}' in {model_dir}" ) - if model_format == "joblib": - self.model = joblib.load(model_file) - elif model_format == "pickle": - with open(model_file, 'rb') as f: - self.model = pickle.load(f) - elif model_format == "cloudpickle": - with open(model_file, 'rb') as f: - self.model = cloudpickle.load(f) - - self.custom_input_formatter = get_annotated_function( - model_dir, "is_input_formatter") - self.custom_output_formatter = get_annotated_function( - model_dir, "is_output_formatter") - self.custom_predict_formatter = get_annotated_function( - model_dir, "is_predict_formatter") + if model_format == "skops": + trusted_types = self._get_trusted_types(properties) + self.model = sio.load(model_file, trusted=trusted_types) + else: + if properties.get("trust_insecure_model_files", + "false").lower() != "true": + raise ValueError( + f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)" + ) + + if model_format == "joblib": + self.model = joblib.load(model_file) + elif model_format == 
"pickle": + with open(model_file, 'rb') as f: + self.model = pickle.load(f) + elif model_format == "cloudpickle": + with open(model_file, 'rb') as f: + self.model = cloudpickle.load(f) self.initialized = True def inference(self, inputs: Input) -> Output: content_type = inputs.get_property("Content-Type") - accept = inputs.get_property("Accept") or "application/json" + properties = inputs.get_properties() + default_accept = self.init_properties.get("default_accept", + "application/json") + + accept = inputs.get_property("Accept") + + # If no accept type is specified in the request, use default + if accept == "*/*": + accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) if not self.custom_output_formatter: @@ -112,7 +165,11 @@ def inference(self, inputs: Input) -> Output: # Input processing X = None if self.custom_input_formatter: - X = self.custom_input_formatter(inputs) + if self.is_sagemaker_script: + X = self.custom_input_formatter(inputs.get_as_bytes(), + content_type) + else: + X = self.custom_input_formatter(inputs) elif "text/csv" in content_type: X = decode(inputs, content_type, require_csv_headers=False) else: @@ -129,17 +186,20 @@ def inference(self, inputs: Input) -> Output: X = X.reshape(1, -1) if self.custom_predict_formatter: - predictions = self.custom_predict_formatter(self.model, X) + predictions = self.custom_predict_formatter(X, self.model) else: predictions = self.model.predict(X) # Output processing - if self.custom_output_formatter: - return self.custom_output_formatter(predictions) - - # Supports CSV/JSON outputs by default outputs = Output() - if "text/csv" in accept: + if self.custom_output_formatter: + if self.is_sagemaker_script: + data = self.custom_output_formatter(predictions, accept) + outputs.add_property("Content-Type", accept) + else: + data = self.custom_output_formatter(predictions) + outputs.add(data) + elif "text/csv" in accept: csv_buffer = StringIO() np.savetxt(csv_buffer, 
predictions, fmt='%s', delimiter=',') outputs.add(csv_buffer.getvalue().rstrip()) diff --git a/engines/python/setup/djl_python/tests/test_encode_decode.py b/engines/python/setup/djl_python/tests/test_encode_decode.py index c253e00f0..183c54684 100644 --- a/engines/python/setup/djl_python/tests/test_encode_decode.py +++ b/engines/python/setup/djl_python/tests/test_encode_decode.py @@ -115,7 +115,8 @@ def test_decode_text_csv(self): mock_decode_csv.return_value = {"inputs": ["test input"]} result = decode(self.mock_input, "text/csv") - mock_decode_csv.assert_called_once_with(self.mock_input) + mock_decode_csv.assert_called_once_with(self.mock_input, + require_headers=True) self.assertEqual(result, {"inputs": ["test input"]}) def test_decode_text_plain(self): diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index 4b2a4c3f7..c27764aaa 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -13,10 +13,21 @@ import glob import logging import os -from typing import Optional, List +import inspect +import importlib.util +from typing import Optional, List, Callable from djl_python import Output from djl_python.inputs import Input +from djl_python.service_loader import load_model_service, has_function_in_module + +# SageMaker function signatures for validation +SAGEMAKER_SIGNATURES = { + 'model_fn': ['model_dir'], + 'input_fn': ['request_body', 'content_type'], + 'predict_fn': ['input_data', 'model'], + 'output_fn': ['prediction', 'accept'] +} class IdCounter: @@ -188,3 +199,59 @@ def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]: ) return all_matches[0] if all_matches else None + + +def _validate_sagemaker_function( + module, func_name: str, + expected_params: List[str]) -> Optional[Callable]: + """ + Validate that function exists and has correct signature + Returns the function if valid, None otherwise + """ + if not hasattr(module, func_name): + return 
None + + func = getattr(module, func_name) + if not callable(func): + return None + + try: + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + + # Check parameter count and names match exactly + if param_names == expected_params: + return func + except (ValueError, TypeError): + # Handle cases where signature inspection fails + pass + + return None + + +def get_sagemaker_function(model_dir: str, + func_name: str) -> Optional[Callable]: + """ + Load and validate SageMaker-style formatter function from model.py + + :param model_dir: model directory containing model.py + :param func_name: SageMaker function name (model_fn, input_fn, predict_fn, output_fn) + :return: Validated function or None if not found/invalid + """ + + if func_name not in SAGEMAKER_SIGNATURES: + return None + + try: + service = load_model_service(model_dir, "model.py", -1) + if has_function_in_module(service.module, func_name): + func = getattr(service.module, func_name) + # Optional: validate signature + expected_params = SAGEMAKER_SIGNATURES[func_name] + if _validate_sagemaker_function(service.module, func_name, + expected_params): + return func + + except Exception as e: + logging.debug(f"Failed to load {func_name} from model.py: {e}") + return None diff --git a/engines/python/setup/djl_python/xgboost_handler.py b/engines/python/setup/djl_python/xgboost_handler.py index 0f08508aa..483b2e939 100644 --- a/engines/python/setup/djl_python/xgboost_handler.py +++ b/engines/python/setup/djl_python/xgboost_handler.py @@ -18,7 +18,7 @@ from typing import Optional from djl_python import Input, Output from djl_python.encode_decode import decode -from djl_python.utils import find_model_file +from djl_python.utils import find_model_file, get_sagemaker_function from djl_python.service_loader import get_annotated_function from djl_python.import_utils import xgboost as xgb @@ -28,8 +28,56 @@ class XGBoostHandler: def __init__(self): self.model = None self.initialized = False + 
self.custom_input_formatter = None + self.custom_output_formatter = None + self.custom_predict_formatter = None + self.custom_model_loading_formatter = None + self.init_properties = None + self.is_sagemaker_script = False + + def _load_custom_formatters(self, model_dir: str): + """Load custom formatters, checking DJL decorators first, then SageMaker functions.""" + # Check for DJL decorator-based custom formatters first + self.custom_model_loading_formatter = get_annotated_function( + model_dir, "is_model_loading_formatter") + self.custom_input_formatter = get_annotated_function( + model_dir, "is_input_formatter") + self.custom_output_formatter = get_annotated_function( + model_dir, "is_output_formatter") + self.custom_predict_formatter = get_annotated_function( + model_dir, "is_predict_formatter") + + # If no decorator-based formatters found, check for SageMaker-style formatters + if not any([ + self.custom_input_formatter, self.custom_output_formatter, + self.custom_predict_formatter, + self.custom_model_loading_formatter + ]): + + sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') + sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') + sagemaker_predict_fn = get_sagemaker_function( + model_dir, 'predict_fn') + sagemaker_output_fn = get_sagemaker_function( + model_dir, 'output_fn') + + if any([ + sagemaker_model_fn, sagemaker_input_fn, + sagemaker_predict_fn, sagemaker_output_fn + ]): + self.is_sagemaker_script = True + if sagemaker_model_fn: + self.custom_model_loading_formatter = sagemaker_model_fn + if sagemaker_input_fn: + self.custom_input_formatter = sagemaker_input_fn + if sagemaker_predict_fn: + self.custom_predict_formatter = sagemaker_predict_fn + if sagemaker_output_fn: + self.custom_output_formatter = sagemaker_output_fn def initialize(self, properties: dict): + # Store initialization properties for use during inference + self.init_properties = properties.copy() model_dir = properties.get("model_dir") model_format = 
(properties.get("model_format") or os.environ.get("MODEL_FORMAT") or "json") @@ -47,42 +95,51 @@ def initialize(self, properties: dict): f"Unsupported model format: {model_format}. Supported formats: json, ubj, pickle, xgb" ) - model_file = find_model_file(model_dir, extensions) - if not model_file: - raise FileNotFoundError( - f"No model file found with format '{model_format}' in {model_dir}" - ) + # Load custom formatters + self._load_custom_formatters(model_dir) - if model_format in ["json", "ubj"]: - self.model = xgb.Booster() - self.model.load_model(model_file) - else: # unsafe formats: pickle, xgb - trust_insecure = (properties.get("trust_insecure_model_files") - or os.environ.get("TRUST_INSECURE_MODEL_FILES") - or "false") - if trust_insecure.lower() != "true": - raise ValueError( - "option.trust_insecure_model_files must be set to 'true' to use unsafe formats (only json/ubj are secure by default)" + if self.custom_model_loading_formatter: + self.model = self.custom_model_loading_formatter(model_dir) + else: + model_file = find_model_file(model_dir, extensions) + if not model_file: + raise FileNotFoundError( + f"No model file found with format '{model_format}' in {model_dir}" ) - if model_format == "pickle": - with open(model_file, 'rb') as f: - self.model = pkl.load(f) - else: # xgb format + + if model_format in ["json", "ubj"]: self.model = xgb.Booster() self.model.load_model(model_file) - - self.custom_input_formatter = get_annotated_function( - model_dir, "is_input_formatter") - self.custom_output_formatter = get_annotated_function( - model_dir, "is_output_formatter") - self.custom_predict_formatter = get_annotated_function( - model_dir, "is_predict_formatter") + else: # unsafe formats: pickle, xgb + trust_insecure = (properties.get("trust_insecure_model_files") + or + os.environ.get("TRUST_INSECURE_MODEL_FILES") + or "false") + if trust_insecure.lower() != "true": + raise ValueError( + "option.trust_insecure_model_files must be set to 'true' to use 
unsafe formats (only json/ubj are secure by default)" + ) + if model_format == "pickle": + with open(model_file, 'rb') as f: + self.model = pkl.load(f) + else: # xgb format + self.model = xgb.Booster() + self.model.load_model(model_file) self.initialized = True def inference(self, inputs: Input) -> Output: content_type = inputs.get_property("Content-Type") - accept = inputs.get_property("Accept") or "application/json" + properties = inputs.get_properties() + # Use initialization properties as fallback for missing request properties + default_accept = self.init_properties.get("default_accept", + "application/json") + + accept = inputs.get_property("Accept") + + # Treat */* as no preference, use default + if accept == "*/*": + accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) if not self.custom_output_formatter: @@ -96,7 +153,11 @@ def inference(self, inputs: Input) -> Output: # Input processing X = None if self.custom_input_formatter: - X = self.custom_input_formatter(inputs) + if self.is_sagemaker_script: + X = self.custom_input_formatter(inputs.get_as_bytes(), + content_type) + else: + X = self.custom_input_formatter(inputs) elif "text/csv" in content_type: X = decode(inputs, content_type, require_csv_headers=False) else: @@ -112,18 +173,22 @@ def inference(self, inputs: Input) -> Output: if X.ndim == 1: X = X.reshape(1, -1) if self.custom_predict_formatter: - predictions = self.custom_predict_formatter(self.model, X) + predictions = self.custom_predict_formatter(X, self.model) else: dmatrix = xgb.DMatrix(X) - predictions = self.model.predict(dmatrix, validate_features=False) + predictions = self.model.predict(dmatrix) # Output processing - if self.custom_output_formatter: - return self.custom_output_formatter(predictions) - - # Supports CSV/JSON outputs by default outputs = Output() - if "text/csv" in accept: + if self.custom_output_formatter: + if self.is_sagemaker_script: + data = 
self.custom_output_formatter(predictions, accept) + outputs.add_property("Content-Type", accept) + else: + data = self.custom_output_formatter(predictions) + outputs.add(data) + + elif "text/csv" in accept: csv_buffer = StringIO() np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',') outputs.add(csv_buffer.getvalue().rstrip()) diff --git a/serving/build.gradle.kts b/serving/build.gradle.kts index 1f3ae41fa..370ed2415 100644 --- a/serving/build.gradle.kts +++ b/serving/build.gradle.kts @@ -147,6 +147,9 @@ tasks { "if [ \"\${MODEL_SERVER_HOME}\" = \"\" ] ; then\n" + " export MODEL_SERVER_HOME=\${APP_HOME}\n" + "fi\n" + + "if [ \"\${SAGEMAKER_MODEL_SERVER_VMARGS}\" != \"\" ] ; then\n" + + " export JAVA_OPTS=\"\$JAVA_OPTS \$SAGEMAKER_MODEL_SERVER_VMARGS\"\n" + + "fi\n" + "if [ -f \"/opt/ml/.sagemaker_infra/endpoint-metadata.json\" ]; then\n" + " export JAVA_OPTS=\"\$JAVA_OPTS -XX:-UseContainerSupport\"\n" + " DEFAULT_JVM_OPTS=\"\${DEFAULT_JVM_OPTS:--Dlog4j.configurationFile=\${APP_HOME}/conf/log4j2-plain.xml}\"\n" + diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index eb453980c..ec30ad371 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -13,6 +13,7 @@ package ai.djl.serving.util; import ai.djl.serving.Arguments; +import ai.djl.serving.wlm.util.SageMakerCompatibility; import ai.djl.serving.wlm.util.WlmConfigManager; import ai.djl.util.Ec2Utils; import ai.djl.util.NeuronUtils; @@ -127,10 +128,15 @@ private ConfigManager(Arguments args) { if (models != null) { prop.setProperty(LOAD_MODELS, String.join(",", models)); } - for (Map.Entry env : Utils.getenv().entrySet()) { - String key = env.getKey(); + // Apply SageMaker compatibility for server-level configurations + SageMakerCompatibility.applyServerCompatibility(prop); + + Map env = Utils.getenv(); + + for (Map.Entry entry : 
env.entrySet()) { + String key = entry.getKey(); if (key.startsWith("SERVING_")) { - prop.put(key.substring(8).toLowerCase(Locale.ROOT), env.getValue()); + prop.put(key.substring(8).toLowerCase(Locale.ROOT), entry.getValue()); } } for (Map.Entry entry : prop.entrySet()) { diff --git a/tests/integration/download_models.sh b/tests/integration/download_models.sh index 03afd0932..f7d75dffd 100755 --- a/tests/integration/download_models.sh +++ b/tests/integration/download_models.sh @@ -35,6 +35,15 @@ python_skl_models_urls=( "https://resources.djl.ai/test-models/python/sklearn/sklearn_unsafe_model_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_skops_model_env_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_sm.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_invalid.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_mixed_djl_sagemaker.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_all_formatters_v3.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_input_output_v3.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_invalid_input_v3.zip" + "https://resources.djl.ai/test-models/python/sklearn/slow_loading_model.zip" + "https://resources.djl.ai/test-models/python/sklearn/slow_predict_model.zip" ) python_xgb_models_urls=( @@ -43,6 +52,13 @@ python_xgb_models_urls=( "https://resources.djl.ai/test-models/python/xgboost/xgboost_deprecated_model_v2.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_unsafe_model_v2.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_custom_model_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_all.zip" + 
"https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output_invalid.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_mixed_djl_sagemaker.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_all_formatters_v3.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_input_output_v3.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_invalid_input_v3.zip" ) download() { diff --git a/tests/integration/test_custom_formatters.py b/tests/integration/test_custom_formatters.py new file mode 100644 index 000000000..fa1a022e7 --- /dev/null +++ b/tests/integration/test_custom_formatters.py @@ -0,0 +1,543 @@ +import os +import requests +import tempfile +import shutil +import zipfile +from tests import Runner + + +@pytest.mark.cpu +class TestCustomFormatters: + + def test_sklearn_all_custom_formatters(self): + """Test sklearn handler with all four custom formatters""" + with Runner('cpu-full', 'sklearn_custom_formatters', + download=True) as r: + r.launch( + cmd= + "serve -m sklearn_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + ) + + # Test custom formatters + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_custom", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert result["probability"] == 0.999 # Custom predict formatter + assert "model_type" in result # Custom output formatter + + def test_sagemaker_env_with_custom_formatters(self): + """Test SageMaker compatibility with custom formatters and env variables""" + with Runner('cpu-full', 'sagemaker_custom_formatters', + download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=2", + 
"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=application/json" + ] + r.launch( + env_vars=env, + cmd= + "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + ) + + # Test with custom formatters - use features format from existing model + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sagemaker_custom", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 200 + result = response.json() + assert result["probability"] == 0.999 # Custom predict formatter + assert "model_type" in result # Custom output formatter + + def test_sagemaker_csv_default_with_json_only_formatter(self): + """Test SageMaker with CSV default but JSON-only output formatter (should fail)""" + with Runner('cpu-full', 'sagemaker_csv_default', download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=1", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=text/csv" + ] + r.launch( + env_vars=env, + cmd= + "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + ) + + # Test should fail because output_fn only supports application/json + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sagemaker_custom", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 424 # Failed dependency - output formatter error + + def test_sagemaker_input_output_formatters(self): + """Test input_fn and output_fn formatters -- should work without other two functions and also work + with default predict logic in handler""" + with Runner('cpu-full', 'sagemaker_input_output', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_io::Python=file:/opt/ml/model/sklearn_custom_model_input_output.zip" + ) + + # Test with SageMaker input/output formatters + test_data = { + "features": + [1.0, 
2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_io", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "original_prediction" in result + assert "doubled_prediction" in result + assert "prediction_sum" in result + assert result["sagemaker_output_fn_used"] == True + # Verify doubled prediction is actually double the original + original = result["original_prediction"][0] + doubled = result["doubled_prediction"][0] + assert doubled == original * 2 + + def test_sagemaker_invalid_input_formatter(self): + """Test SageMaker input_fn that returns invalid format (should fail with default handler predict)""" + with Runner('cpu-full', 'sagemaker_invalid_input', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_invalid::Python=file:/opt/ml/model/sklearn_custom_model_input_output_invalid.zip" + ) + + # Test should fail because input_fn returns raw list instead of numpy array + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_invalid", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 424 # Failed dependency - input processing error + + def test_xgboost_all_sagemaker_formatters(self): + """Test XGBoost handler with all four SageMaker formatters""" + with Runner('cpu-full', 'xgboost_sagemaker_all', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_all::Python=file:/opt/ml/model/xgboost_sagemaker_all.zip" + ) + + # Test SageMaker formatters + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_all", + json=test_data, + headers={ + "Content-Type": "application/json", 
+ "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.888 # Custom predict formatter + assert result["custom_xgb"] == True # Custom output formatter + assert "model_type" in result + + def test_xgboost_sagemaker_env_with_formatters(self): + """Test XGBoost SageMaker compatibility with env variables""" + with Runner('cpu-full', 'xgboost_sagemaker_env', download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=2", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=application/json" + ] + r.launch( + env_vars=env, + cmd= + "serve -m xgboost_env::Python=file:/opt/ml/model/xgboost_sagemaker_all.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_env", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.888 + assert result["custom_xgb"] == True + + def test_xgboost_csv_default_with_json_only_formatter(self): + """Test XGBoost SageMaker with CSV default but JSON-only output formatter (should fail)""" + with Runner('cpu-full', 'xgboost_csv_default', download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=1", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=text/csv" + ] + r.launch( + env_vars=env, + cmd= + "serve -m xgboost_csv::Python=file:/opt/ml/model/xgboost_sagemaker_all.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_csv", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 424 # Failed dependency - output formatter error + + def test_xgboost_sagemaker_input_output_formatters(self): + """Test XGBoost SageMaker input_fn and output_fn formatters""" + with Runner('cpu-full', 'xgboost_input_output', 
download=True) as r: + r.launch( + cmd= + "serve -m xgboost_io::Python=file:/opt/ml/model/xgboost_sagemaker_input_output.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_io", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "original_prediction" in result + assert "doubled_prediction" in result + assert "prediction_sum" in result + assert result["sagemaker_output_fn_used"] == True + # Verify doubled prediction is actually double the original + original = result["original_prediction"][0] + doubled = result["doubled_prediction"][0] + assert doubled == original * 2 + + def test_xgboost_sagemaker_invalid_input_formatter(self): + """Test XGBoost SageMaker input_fn that returns invalid format (should fail)""" + with Runner('cpu-full', 'xgboost_invalid_input', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_invalid::Python=file:/opt/ml/model/xgboost_sagemaker_input_output_invalid.zip" + ) + + # Test should fail because input_fn returns raw list instead of numpy array + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_invalid", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 424 # Failed dependency - input processing error + + def test_sklearn_mixed_djl_sagemaker_formatters(self): + """Test sklearn with mixed DJL and SageMaker formatters - DJL should take precedence and SageMaker functions should be ignored""" + with Runner('cpu-full', 'sklearn_mixed_formatters', + download=True) as r: + r.launch( + cmd= + "serve -m sklearn_mixed::Python=file:/opt/ml/model/sklearn_mixed_djl_sagemaker.zip" + ) + + # When DJL decorators 
are present, SageMaker functions should be completely ignored + # This means only DJL model loader will be used, and default input/predict/output processing will be used + test_data = { + "inputs": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_mixed", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + # Should get default sklearn prediction output format, not custom SageMaker output + assert "predictions" in result # Default output format + assert isinstance(result["predictions"], list) + # Verify the model was loaded with DJL decorator (it has djl_loaded=True attribute) + + def test_xgboost_mixed_djl_sagemaker_formatters(self): + """Test xgboost with mixed DJL and SageMaker formatters - DJL should take precedence and SageMaker functions should be ignored""" + with Runner('cpu-full', 'xgboost_mixed_formatters', + download=True) as r: + r.launch( + cmd= + "serve -m xgboost_mixed::Python=file:/opt/ml/model/xgboost_mixed_djl_sagemaker.zip" + ) + + # When DJL decorators are present, SageMaker functions should be completely ignored + # This means only DJL model loader will be used, and default input/predict/output processing will be used + test_data = { + "inputs": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_mixed", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + # Should get default xgboost prediction output format, not custom SageMaker output + assert "predictions" in result # Default output format + assert isinstance(result["predictions"], list) + # Verify the model was loaded with DJL decorator (it has djl_loaded=True attribute) + + def 
test_sklearn_djl_all_formatters(self): + """Test sklearn handler with all four DJL decorators""" + with Runner('cpu-full', 'sklearn_djl_all', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_djl_all::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v3.zip" + ) + + # Test DJL decorators + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_djl_all", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.777 # Custom predict formatter + assert result["custom_sklearn"] == True # Custom output formatter + assert result["formatter_type"] == "djl_decorators" + assert "model_type" in result + + def test_sklearn_djl_env_with_formatters(self): + """Test sklearn DJL compatibility with env variables""" + with Runner('cpu-full', 'sklearn_djl_env', download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=2", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=application/json" + ] + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_djl_env::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v3.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_djl_env", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.777 + assert result["custom_sklearn"] == True + assert result["formatter_type"] == "djl_decorators" + + def test_sklearn_djl_input_output_formatters(self): + """Test sklearn DJL input_formatter and output_formatter decorators""" + with Runner('cpu-full', 'sklearn_djl_input_output', + download=True) as r: + r.launch( + cmd= + "serve -m 
sklearn_djl_io::Python=file:/opt/ml/model/sklearn_djl_input_output_v3.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_djl_io", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "original_prediction" in result + assert "doubled_prediction" in result + assert "prediction_sum" in result + assert result["djl_output_formatter_used"] == True + assert result["formatter_type"] == "djl_decorators" + # Verify doubled prediction is actually double the original + original = result["original_prediction"][0] + doubled = result["doubled_prediction"][0] + assert doubled == original * 2 + + def test_sklearn_djl_invalid_input_formatter(self): + """Test sklearn DJL input_formatter that returns invalid format (should fail)""" + with Runner('cpu-full', 'sklearn_djl_invalid_input', + download=True) as r: + r.launch( + cmd= + "serve -m sklearn_djl_invalid::Python=file:/opt/ml/model/sklearn_djl_invalid_input_v3.zip" + ) + + # Test should fail because input_formatter returns raw list instead of numpy array + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_djl_invalid", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 424 # Failed dependency - input processing error + + def test_xgboost_djl_all_formatters(self): + """Test XGBoost handler with all four DJL decorators""" + with Runner('cpu-full', 'xgboost_djl_all', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_djl_all::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v3.zip" + ) + + # Test DJL decorators + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 
10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_djl_all", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.999 # Custom predict formatter + assert result["custom_xgb"] == True # Custom output formatter + assert result["formatter_type"] == "djl_decorators" + assert "model_type" in result + + def test_xgboost_djl_env_with_formatters(self): + """Test XGBoost DJL compatibility with env variables""" + with Runner('cpu-full', 'xgboost_djl_env', download=True) as r: + env = [ + "SAGEMAKER_NUM_MODEL_WORKERS=2", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=application/json" + ] + r.launch( + env_vars=env, + cmd= + "serve -m xgboost_djl_env::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v3.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_djl_env", + json=test_data, + headers={"Content-Type": "application/json"}) + assert response.status_code == 200 + result = response.json() + assert result["prediction"] == 0.999 + assert result["custom_xgb"] == True + assert result["formatter_type"] == "djl_decorators" + + def test_xgboost_djl_input_output_formatters(self): + """Test XGBoost DJL input_formatter and output_formatter decorators""" + with Runner('cpu-full', 'xgboost_djl_input_output', + download=True) as r: + r.launch( + cmd= + "serve -m xgboost_djl_io::Python=file:/opt/ml/model/xgboost_djl_input_output_v3.zip" + ) + + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_djl_io", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert 
"original_prediction" in result + assert "doubled_prediction" in result + assert "prediction_sum" in result + assert result["djl_output_formatter_used"] == True + assert result["formatter_type"] == "djl_decorators" + # Verify doubled prediction is actually double the original + original = result["original_prediction"][0] + doubled = result["doubled_prediction"][0] + assert doubled == original * 2 + + def test_xgboost_djl_invalid_input_formatter(self): + """Test XGBoost DJL input_formatter that returns invalid format (should fail)""" + with Runner('cpu-full', 'xgboost_djl_invalid_input', + download=True) as r: + r.launch( + cmd= + "serve -m xgboost_djl_invalid::Python=file:/opt/ml/model/xgboost_djl_invalid_input_v3.zip" + ) + + # Test should fail because input_formatter returns raw list instead of numpy array + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_djl_invalid", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 424 # Failed dependency - input processing error diff --git a/tests/integration/test_sagemaker_compatibility.py b/tests/integration/test_sagemaker_compatibility.py new file mode 100644 index 000000000..cf066d179 --- /dev/null +++ b/tests/integration/test_sagemaker_compatibility.py @@ -0,0 +1,279 @@ +import os +import requests +from tests import Runner + + +@pytest.mark.cpu +class TestSageMakerCompatibility: + + def test_sagemaker_num_workers(self): + with Runner('cpu-full', 'sagemaker_num_workers', download=True) as r: + env = ["SAGEMAKER_NUM_MODEL_WORKERS=3"] + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=test_data, + 
headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_sagemaker_max_request_size(self): + with Runner('cpu-full', 'sagemaker_max_request_size', + download=True) as r: + env = ["SAGEMAKER_MAX_REQUEST_SIZE=1024"] # 1KB + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + + # Test with large request (should fail - exceeds 1KB limit) + large_data = { + "inputs": + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] * + 50 # 50 samples + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=large_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 413 # Request Entity Too Large + + def test_sagemaker_startup_timeout(self): + with Runner('cpu-full', 'sagemaker_startup_timeout', + download=True) as r: + env = ["SAGEMAKER_STARTUP_TIMEOUT=300"] # 5 minutes + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + # Verify model loaded successfully (SageMaker compatibility working) + response = requests.get( + "http://localhost:8080/models/sklearn_test") + assert response.status_code == 200 + + # Verify timeout setting was applied by checking server logs + log_file = os.path.join(os.getcwd(), "logs", "serving.log") + print(f"Checking log file: {log_file}") + if os.path.exists(log_file): + print("Log file found, parsing content") + with open(log_file, 'r') as f: + log_content = f.read() + assert '"model_loading_timeout":"300' in log_content, "SAGEMAKER_STARTUP_TIMEOUT not found in logs" + + def test_sagemaker_predict_timeout(self): + with Runner('cpu-full', 'sagemaker_predict_timeout', + download=True) as r: + env = ["SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS=120"] # 2 minutes + r.launch( + env_vars=env, + cmd= + 
"serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + # Verify model loaded successfully (SageMaker compatibility working) + response = requests.get( + "http://localhost:8080/models/sklearn_test") + assert response.status_code == 200 + + # Verify timeout setting was applied by checking server logs + log_file = os.path.join(os.getcwd(), "logs", "serving.log") + print(f"Checking log file: {log_file}") + if os.path.exists(log_file): + print("Log file found, parsing content") + with open(log_file, 'r') as f: + log_content = f.read() + assert '"predict_timeout":"120' in log_content, "SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS not found in logs" + + def test_sagemaker_model_server_vmargs(self): + with Runner('cpu-full', 'sagemaker_vmargs', download=True) as r: + env = [ + "SAGEMAKER_MODEL_SERVER_VMARGS=-Dsagemaker.test=true -Xmx2g", + "SAGEMAKER_NUM_MODEL_WORKERS=2" + ] + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + # Verify model loaded successfully + response = requests.get( + "http://localhost:8080/models/sklearn_test") + assert response.status_code == 200 + + # Verify JVM arguments were applied by checking server logs + log_file = os.path.join(os.getcwd(), "logs", "serving.log") + print(f"Checking log file: {log_file}") + if os.path.exists(log_file): + print("Log file found, parsing content") + with open(log_file, 'r') as f: + log_content = f.read() + # Check for our custom system property + assert "-Dsagemaker.test=true" in log_content, "SAGEMAKER_MODEL_SERVER_VMARGS system property not found in logs" + # Check for memory override + assert "-Xmx2g" in log_content, "SAGEMAKER_MODEL_SERVER_VMARGS memory setting not found in logs" + # Check that heap size was actually set to 2048MB + assert "Max heap size: 2048" in log_content, "JVM heap size was not overridden correctly" + else: + print("Log file not found, skipping JVM argument verification") + + # Test inference to ensure JVM 
args didn't break functionality + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_sagemaker_default_invocations_accept(self): + with Runner('cpu-full', 'sagemaker_default_accept', + download=True) as r: + env = ["SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT=text/csv"] + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + # Test without Accept header - should return CSV due to default + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=test_data, + headers={"Content-Type": "application/json"}) + print(f"Response status: {response.status_code}") + print(f"Response headers: {dict(response.headers)}") + print(f"Response text: {response.text}") + assert response.status_code == 200 + assert "text/csv" in response.headers.get("Content-Type", "") + + def test_sagemaker_startup_timeout_failure(self): + """Test that SAGEMAKER_STARTUP_TIMEOUT causes model loading to fail when timeout is too short""" + with Runner('cpu-full', + 'sagemaker_startup_timeout_fail', + download=True) as r: + env = ["SAGEMAKER_STARTUP_TIMEOUT=3"] # 3 seconds timeout + try: + r.launch( + env_vars=env, + cmd= + "serve -m slow_model::Python=file:/opt/ml/model/slow_loading_model.zip" + ) + # If we get here, the server started but model loading should have failed + # Check that the model is not available + response = requests.get( + "http://localhost:8080/models/slow_model") + assert response.status_code == 404 # Model not found due to timeout + except 
Exception as e: + # Expected - server should fail to start or model should fail to load + print(f"Expected failure due to timeout: {e}") + assert True + + def test_sagemaker_predict_timeout_failure(self): + """Test that SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS causes prediction to timeout""" + with Runner('cpu-full', + 'sagemaker_predict_timeout_fail', + download=True) as r: + env = ["SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS=3" + ] # 3 seconds timeout + r.launch( + env_vars=env, + cmd= + "serve -m slow_predict::Python=file:/opt/ml/model/slow_predict_model.zip" + ) + + response = requests.get( + "http://localhost:8080/models/slow_predict") + assert response.status_code == 200 + + # But prediction should timeout (10 second predict_fn vs 3 second timeout) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/slow_predict", + json=test_data, + headers={"Content-Type": "application/json"}, + timeout=15 # Give enough time for the request itself + ) + # Should get timeout error from DJL serving + assert response.status_code in [408, 500, + 503] # Timeout or server error + + def test_sagemaker_max_payload_in_mb(self): + """Test SAGEMAKER_MAX_PAYLOAD_IN_MB conversion to bytes""" + with Runner('cpu-full', 'sagemaker_max_payload_mb', + download=True) as r: + env = ["SAGEMAKER_MAX_PAYLOAD_IN_MB=1"] # 1MB = 1048576 bytes + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + + # Test with large request (should fail - exceeds 1MB limit) + large_array = [1.0] * 50000 # 50k floats + large_data = { + "inputs": + [large_array] * 10 # 10 arrays of 50k floats each = ~20MB + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=large_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 413 # Request Entity Too Large + + def 
test_sagemaker_max_request_size_precedence(self): + """Test that SAGEMAKER_MAX_REQUEST_SIZE takes precedence over SAGEMAKER_MAX_PAYLOAD_IN_MB""" + with Runner('cpu-full', 'sagemaker_precedence', download=True) as r: + env = [ + "SAGEMAKER_MAX_REQUEST_SIZE=1024", + "SAGEMAKER_MAX_PAYLOAD_IN_MB=1" # 1MB = 1048576 bytes (should be ignored) + ] + r.launch( + env_vars=env, + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + + # Test with request larger than 1KB but smaller than 1MB + # Should fail because SAGEMAKER_MAX_REQUEST_SIZE=1024 takes precedence + medium_data = { + "inputs": + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] * + 50 # ~5KB payload + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=medium_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 413 # Should fail due to 1KB limit, not 1MB limit diff --git a/tests/integration/test_xgb_skl.py b/tests/integration/test_xgb_skl.py index f9f675539..87275c9dd 100644 --- a/tests/integration/test_xgb_skl.py +++ b/tests/integration/test_xgb_skl.py @@ -241,59 +241,6 @@ def test_xgboost_deprecated_format(self): result = response.json() assert "predictions" in result - # Custom formatter tests - def test_sklearn_custom_formatters(self): - with Runner('cpu-full', 'sklearn_custom', download=True) as r: - r.launch( - cmd= - "serve -m sklearn_custom::Python=file:/opt/ml/model/sklearn_custom_model_v2.zip" - ) - test_data = { - "features": - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] - } - response = requests.post( - "http://localhost:8080/predictions/sklearn_custom", - data=json.dumps(test_data), - headers={ - "Content-Type": "application/json", - "Accept": "application/json" - }) - assert response.status_code == 200 - result = response.json() - assert "result" in result - assert "confidence" in result - assert "model_type" in result - assert 
result["model_type"] == "sklearn_custom" - assert result["confidence"] == 0.95 - - def test_xgboost_custom_formatters(self): - with Runner('cpu-full', 'xgboost_custom', download=True) as r: - r.launch( - cmd= - "serve -m xgboost_custom::Python=file:/opt/ml/model/xgboost_custom_model_v2.zip" - ) - test_data = { - "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] - } - response = requests.post( - "http://localhost:8080/predictions/xgboost_custom", - data=json.dumps(test_data), - headers={ - "Content-Type": "application/json", - "Accept": "application/json" - }) - assert response.status_code == 200 - result = response.json() - assert "probability" in result - assert "prediction" in result - assert "model_version" in result - assert "processed_by" in result - assert result["model_version"] == "1.0" - assert result["processed_by"] == "xgboost_custom" - assert isinstance(result["probability"], float) - assert result["prediction"] in [0, 1] - # Error handling tests - CSV format errors def test_sklearn_csv_with_headers(self): with Runner('cpu-full', 'sklearn_csv_headers', download=True) as r: @@ -462,7 +409,12 @@ def test_xgboost_wrong_input_shape(self): cmd= "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip" ) - test_data = {"inputs": [[1.0, 2.0, 3.0, 4.0, 5.0]]} + test_data = { + "inputs": [[ + 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, + 3.0, 4.0, 5.0 + ]] + } response = requests.post( "http://localhost:8080/predictions/xgboost_test", json=test_data, diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index ac464f1ba..ae675ee29 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -34,6 +34,7 @@ import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.serving.wlm.util.EventManager; +import ai.djl.serving.wlm.util.SageMakerCompatibility; import 
ai.djl.serving.wlm.util.WlmConfigManager; import ai.djl.serving.wlm.util.WlmOutOfMemoryException; import ai.djl.translate.NoopServingTranslatorFactory; @@ -888,6 +889,10 @@ private void loadServingProperties() { logger.warn("{}: Failed read serving.properties file", uid, e); } } + + // Apply SageMaker compatibility for model-level configurations + SageMakerCompatibility.applyModelCompatibility(prop); + // load default settings from env for (Map.Entry entry : Utils.getenv().entrySet()) { String key = entry.getKey(); @@ -906,6 +911,7 @@ private void loadServingProperties() { prop.putIfAbsent("engine", value); continue; } + logger.debug("{}: Setting model option {}={}", uid, key, value); prop.putIfAbsent("option." + key, value); } else if (key.startsWith("ARGS_")) { key = key.substring(5); diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/SageMakerCompatibility.java b/wlm/src/main/java/ai/djl/serving/wlm/util/SageMakerCompatibility.java new file mode 100644 index 000000000..9c196fd2e --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/SageMakerCompatibility.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; + +/** + * Utility class for SageMaker backwards compatibility with XGBoost and SKLearn containers. + * Translates SageMaker environment variables to DJL equivalents. 
+ */ +public final class SageMakerCompatibility { + + private static final Logger logger = LoggerFactory.getLogger(SageMakerCompatibility.class); + + private SageMakerCompatibility() {} + + /** + * Applies SageMaker compatibility for server-level configurations. Called during ConfigManager + * initialization. + * + * @param properties the properties to modify + */ + public static void applyServerCompatibility(Properties properties) { + String maxRequestSize = System.getenv("SAGEMAKER_MAX_REQUEST_SIZE"); + String maxPayloadMb = System.getenv("SAGEMAKER_MAX_PAYLOAD_IN_MB"); + + // SAGEMAKER_MAX_REQUEST_SIZE takes precedence over SAGEMAKER_MAX_PAYLOAD_IN_MB + if (maxRequestSize != null) { + logger.info( + "SageMaker compatibility - translating SAGEMAKER_MAX_REQUEST_SIZE={} to" + + " max_request_size", + maxRequestSize); + properties.setProperty("max_request_size", maxRequestSize); + + if (maxPayloadMb != null) { + logger.warn( + "Both SAGEMAKER_MAX_REQUEST_SIZE and SAGEMAKER_MAX_PAYLOAD_IN_MB are set." + + " Using SAGEMAKER_MAX_REQUEST_SIZE={} and ignoring" + + " SAGEMAKER_MAX_PAYLOAD_IN_MB={}", + maxRequestSize, + maxPayloadMb); + } + } else if (maxPayloadMb != null) { + try { + long payloadBytes = Long.parseLong(maxPayloadMb) * 1024 * 1024; + logger.info( + "SageMaker compatibility - translating SAGEMAKER_MAX_PAYLOAD_IN_MB={} to" + + " max_request_size={} bytes", + maxPayloadMb, + payloadBytes); + properties.setProperty("max_request_size", String.valueOf(payloadBytes)); + } catch (NumberFormatException e) { + logger.warn("Invalid SAGEMAKER_MAX_PAYLOAD_IN_MB value: {}", maxPayloadMb); + } + } + } + + /** + * Applies SageMaker compatibility for model-level configurations. Called during + * ModelInfo.loadServingProperties(). 
+ * + * @param properties the properties to modify + */ + public static void applyModelCompatibility(Properties properties) { + String numWorkers = System.getenv("SAGEMAKER_NUM_MODEL_WORKERS"); + if (numWorkers != null) { + logger.info( + "SageMaker compatibility - translating SAGEMAKER_NUM_MODEL_WORKERS={} to" + + " minWorkers/maxWorkers", + numWorkers); + properties.setProperty("minWorkers", numWorkers); + properties.setProperty("maxWorkers", numWorkers); + } + + String startupTimeout = System.getenv("SAGEMAKER_STARTUP_TIMEOUT"); + if (startupTimeout != null) { + logger.info( + "SageMaker compatibility - translating SAGEMAKER_STARTUP_TIMEOUT={} to" + + " model_loading_timeout", + startupTimeout); + properties.setProperty("option.model_loading_timeout", startupTimeout); + } + + String predictTimeout = System.getenv("SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS"); + if (predictTimeout != null) { + logger.info( + "SageMaker compatibility - translating" + + " SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS={} to predict_timeout", + predictTimeout); + properties.setProperty("option.predict_timeout", predictTimeout); + } + + String defaultAccept = System.getenv("SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"); + if (defaultAccept != null) { + String entryPoint = properties.getProperty("option.entryPoint"); + if (isSklearnOrXgboostHandler(entryPoint)) { + logger.info( + "SageMaker compatibility - setting default accept type to {}", + defaultAccept); + properties.setProperty("option.default_accept", defaultAccept); + } + } + } + + private static boolean isSklearnOrXgboostHandler(String entryPoint) { + return entryPoint != null + && ("djl_python.sklearn_handler".equals(entryPoint) + || "djl_python.xgboost_handler".equals(entryPoint)); + } +} From a3638e084bd68f27ef1def04d82023ead66e4508 Mon Sep 17 00:00:00 2001 From: smouaa Date: Wed, 29 Oct 2025 09:22:26 +0000 Subject: [PATCH 2/3] Encapsulated custom handler/formatter logic and added SageMaker endpoint integration tests --- .gitignore | 2 + 
.../djl_python/custom_formatter_handling.py | 106 ++++- .../python/setup/djl_python/input_parser.py | 12 +- .../setup/djl_python/sklearn_handler.py | 85 +--- .../setup/djl_python/xgboost_handler.py | 85 +--- tests/integration/download_models.sh | 14 +- .../run_all_ml_sm_endpoint_tests.sh | 87 ++++ .../sagemaker-ml-endpoint-tests.py | 437 ++++++++++++++++++ tests/integration/test_custom_formatters.py | 23 +- .../test_sagemaker_compatibility.py | 1 + 10 files changed, 690 insertions(+), 162 deletions(-) create mode 100755 tests/integration/run_all_ml_sm_endpoint_tests.sh create mode 100644 tests/integration/sagemaker-ml-endpoint-tests.py diff --git a/.gitignore b/.gitignore index d42e7d689..953adc42c 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ node_modules/ # dir tests/integration/models/ +tests/integration/sagemaker_test_results.txt engines/python/setup/djl_python/tests/resources* tests/integration/awscurl @@ -39,3 +40,4 @@ __pycache__ dist/ *.egg-info/ *.pt + diff --git a/engines/python/setup/djl_python/custom_formatter_handling.py b/engines/python/setup/djl_python/custom_formatter_handling.py index ced29c1c9..9e41f6b07 100644 --- a/engines/python/setup/djl_python/custom_formatter_handling.py +++ b/engines/python/setup/djl_python/custom_formatter_handling.py @@ -11,8 +11,12 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. 
import logging +import os +from dataclasses import dataclass +from typing import Optional, Callable from djl_python.service_loader import get_annotated_function +from djl_python.utils import get_sagemaker_function logger = logging.getLogger(__name__) @@ -26,31 +30,103 @@ def __init__(self, message: str, original_exception: Exception): self.__cause__ = original_exception +@dataclass +class CustomFormatters: + """Container for input/output formatting functions""" + input_formatter: Optional[Callable] = None + output_formatter: Optional[Callable] = None + + +@dataclass +class CustomHandlers: + """Container for prediction/initialization handler functions""" + prediction_handler: Optional[Callable] = None + init_handler: Optional[Callable] = None + + +@dataclass +class CustomCode: + """Container for all custom formatters and handlers""" + formatters: CustomFormatters + handlers: CustomHandlers + is_sagemaker_script: bool = False + + def __init__(self): + self.formatters = CustomFormatters() + self.handlers = CustomHandlers() + self.is_sagemaker_script = False + + class CustomFormatterHandler: def __init__(self): - self.output_formatter = None - self.input_formatter = None + self.custom_code = CustomCode() - def load_formatters(self, model_dir: str): - """Load custom formatters from model.py""" + def load_formatters(self, model_dir: str) -> CustomCode: + """Load custom formatters/handlers from model.py with SageMaker detection""" try: - self.input_formatter = get_annotated_function( + self.custom_code.formatters.input_formatter = get_annotated_function( model_dir, "is_input_formatter") - self.output_formatter = get_annotated_function( + self.custom_code.formatters.output_formatter = get_annotated_function( model_dir, "is_output_formatter") + self.custom_code.handlers.prediction_handler = get_annotated_function( + model_dir, "is_prediction_handler") + self.custom_code.handlers.init_handler = get_annotated_function( + model_dir, "is_init_handler") + + # Detect SageMaker 
script pattern for backward compatibility + self._detect_sagemaker_functions(model_dir) + logger.info( - f"Loaded formatters - input: {self.input_formatter}, output: {self.output_formatter}" + f"Loaded formatters - input: {bool(self.custom_code.formatters.input_formatter)}, " + f"output: {bool(self.custom_code.formatters.output_formatter)}" ) + logger.info( + f"Loaded handlers - prediction: {bool(self.custom_code.handlers.prediction_handler)}, " + f"init: {bool(self.custom_code.handlers.init_handler)}, " + f"sagemaker: {self.custom_code.is_sagemaker_script}") + return self.custom_code except Exception as e: raise CustomFormatterError( - f"Failed to load custom formatters from {model_dir}", e) + f"Failed to load custom code from {model_dir}", e) + + def _detect_sagemaker_functions(self, model_dir: str): + """Detect and load SageMaker-style functions for backward compatibility""" + # If no decorator-based code found, check for SageMaker functions + if not any([ + self.custom_code.formatters.input_formatter, + self.custom_code.formatters.output_formatter, + self.custom_code.handlers.prediction_handler, + self.custom_code.handlers.init_handler + ]): + sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') + sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') + sagemaker_predict_fn = get_sagemaker_function( + model_dir, 'predict_fn') + sagemaker_output_fn = get_sagemaker_function( + model_dir, 'output_fn') + + if any([ + sagemaker_model_fn, sagemaker_input_fn, + sagemaker_predict_fn, sagemaker_output_fn + ]): + self.custom_code.is_sagemaker_script = True + if sagemaker_model_fn: + self.custom_code.handlers.init_handler = sagemaker_model_fn + if sagemaker_input_fn: + self.custom_code.formatters.input_formatter = sagemaker_input_fn + if sagemaker_predict_fn: + self.custom_code.handlers.prediction_handler = sagemaker_predict_fn + if sagemaker_output_fn: + self.custom_code.formatters.output_formatter = sagemaker_output_fn + logger.info("Loaded 
SageMaker-style functions") def apply_input_formatter(self, decoded_payload, **kwargs): """Apply input formatter if available""" - if self.input_formatter: + if self.custom_code.formatters.input_formatter: try: - return self.input_formatter(decoded_payload, **kwargs) + return self.custom_code.formatters.input_formatter( + decoded_payload, **kwargs) except Exception as e: logger.exception("Custom input formatter failed") raise CustomFormatterError( @@ -59,9 +135,9 @@ def apply_input_formatter(self, decoded_payload, **kwargs): def apply_output_formatter(self, output): """Apply output formatter if available""" - if self.output_formatter: + if self.custom_code.formatters.output_formatter: try: - return self.output_formatter(output) + return self.custom_code.formatters.output_formatter(output) except Exception as e: logger.exception("Custom output formatter failed") raise CustomFormatterError( @@ -79,3 +155,9 @@ async def apply_output_formatter_streaming_raw(self, stream_generator): logger.exception("Streaming formatter failed") raise CustomFormatterError( "Custom streaming formatter execution failed", e) + + +def load_custom_code(model_dir: str) -> CustomCode: + """Load custom code, checking DJL decorators first, then SageMaker functions""" + handler = CustomFormatterHandler() + return handler.load_formatters(model_dir) diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index bc54066ec..62cf5ea42 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -36,25 +36,25 @@ def input_formatter(function): return function -def predict_formatter(function): +def prediction_handler(function): """ - Decorator for predict_formatter. User just need to annotate @predict_formatter for their custom defined function. + Decorator for prediction_handler. User just need to annotate @prediction_handler for their custom defined function. 
:param function: Decorator takes in the function and adds an attribute. :return: """ # adding an attribute to the function, which is used to find the decorated function. - function.is_predict_formatter = True + function.is_prediction_handler = True return function -def model_loading_formatter(function): +def init_handler(function): """ - Decorator for model_loading_formatter. User just need to annotate @model_loading_formatter for their custom defined function. + Decorator for init_handler. User just need to annotate @init_handler for their custom defined function. :param function: Decorator takes in the function and adds an attribute. :return: """ # adding an attribute to the function, which is used to find the decorated function. - function.is_model_loading_formatter = True + function.is_init_handler = True return function diff --git a/engines/python/setup/djl_python/sklearn_handler.py b/engines/python/setup/djl_python/sklearn_handler.py index 1fd66027f..4e0393dd0 100644 --- a/engines/python/setup/djl_python/sklearn_handler.py +++ b/engines/python/setup/djl_python/sklearn_handler.py @@ -18,8 +18,8 @@ from typing import Optional from djl_python import Input, Output from djl_python.encode_decode import decode -from djl_python.utils import find_model_file, get_sagemaker_function -from djl_python.service_loader import get_annotated_function +from djl_python.utils import find_model_file +from djl_python.custom_formatter_handling import load_custom_code from djl_python.import_utils import joblib, cloudpickle, skops_io as sio @@ -28,52 +28,8 @@ class SklearnHandler: def __init__(self): self.model = None self.initialized = False - self.custom_input_formatter = None - self.custom_output_formatter = None - self.custom_predict_formatter = None - self.custom_model_loading_formatter = None + self.custom_code = None self.init_properties = None - self.is_sagemaker_script = False - - def _load_custom_formatters(self, model_dir: str): - """Load custom formatters, checking DJL 
decorators first, then SageMaker functions.""" - # Check for DJL decorator-based custom formatters first - self.custom_model_loading_formatter = get_annotated_function( - model_dir, "is_model_loading_formatter") - self.custom_input_formatter = get_annotated_function( - model_dir, "is_input_formatter") - self.custom_output_formatter = get_annotated_function( - model_dir, "is_output_formatter") - self.custom_predict_formatter = get_annotated_function( - model_dir, "is_predict_formatter") - - # If no decorator-based formatters found, check for SageMaker-style formatters - if not any([ - self.custom_input_formatter, self.custom_output_formatter, - self.custom_predict_formatter, - self.custom_model_loading_formatter - ]): - - sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') - sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') - sagemaker_predict_fn = get_sagemaker_function( - model_dir, 'predict_fn') - sagemaker_output_fn = get_sagemaker_function( - model_dir, 'output_fn') - - if any([ - sagemaker_model_fn, sagemaker_input_fn, - sagemaker_predict_fn, sagemaker_output_fn - ]): - self.is_sagemaker_script = True - if sagemaker_model_fn: - self.custom_model_loading_formatter = sagemaker_model_fn - if sagemaker_input_fn: - self.custom_input_formatter = sagemaker_input_fn - if sagemaker_predict_fn: - self.custom_predict_formatter = sagemaker_predict_fn - if sagemaker_output_fn: - self.custom_output_formatter = sagemaker_output_fn def _get_trusted_types(self, properties: dict): trusted_types_str = properties.get("skops_trusted_types", "") @@ -107,12 +63,12 @@ def initialize(self, properties: dict): f"Unsupported model format: {model_format}. 
Supported formats: skops, joblib, pickle, cloudpickle" ) - # Load custom formatters - self._load_custom_formatters(model_dir) + # Load custom code + self.custom_code = load_custom_code(model_dir) # Load model - if self.custom_model_loading_formatter: - self.model = self.custom_model_loading_formatter(model_dir) + if self.custom_code.handlers.init_handler: + self.model = self.custom_code.handlers.init_handler(model_dir) else: model_file = find_model_file(model_dir, extensions) if not model_file: @@ -154,7 +110,7 @@ def inference(self, inputs: Input) -> Output: accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) - if not self.custom_output_formatter: + if not self.custom_code.formatters.output_formatter: supported_accept_types = ["application/json", "text/csv"] if not any(supported_type in accept for supported_type in supported_accept_types): @@ -164,12 +120,12 @@ def inference(self, inputs: Input) -> Output: # Input processing X = None - if self.custom_input_formatter: - if self.is_sagemaker_script: - X = self.custom_input_formatter(inputs.get_as_bytes(), - content_type) + if self.custom_code.formatters.input_formatter: + if self.custom_code.is_sagemaker_script: + X = self.custom_code.formatters.input_formatter( + inputs.get_as_bytes(), content_type) else: - X = self.custom_input_formatter(inputs) + X = self.custom_code.formatters.input_formatter(inputs) elif "text/csv" in content_type: X = decode(inputs, content_type, require_csv_headers=False) else: @@ -185,19 +141,22 @@ def inference(self, inputs: Input) -> Output: if X.ndim == 1: X = X.reshape(1, -1) - if self.custom_predict_formatter: - predictions = self.custom_predict_formatter(X, self.model) + if self.custom_code.handlers.prediction_handler: + predictions = self.custom_code.handlers.prediction_handler( + X, self.model) else: predictions = self.model.predict(X) # Output processing outputs = Output() - if self.custom_output_formatter: - if 
self.is_sagemaker_script: - data = self.custom_output_formatter(predictions, accept) + if self.custom_code.formatters.output_formatter: + if self.custom_code.is_sagemaker_script: + data = self.custom_code.formatters.output_formatter( + predictions, accept) outputs.add_property("Content-Type", accept) else: - data = self.custom_output_formatter(predictions) + data = self.custom_code.formatters.output_formatter( + predictions) outputs.add(data) elif "text/csv" in accept: csv_buffer = StringIO() diff --git a/engines/python/setup/djl_python/xgboost_handler.py b/engines/python/setup/djl_python/xgboost_handler.py index 483b2e939..23244ff06 100644 --- a/engines/python/setup/djl_python/xgboost_handler.py +++ b/engines/python/setup/djl_python/xgboost_handler.py @@ -18,8 +18,8 @@ from typing import Optional from djl_python import Input, Output from djl_python.encode_decode import decode -from djl_python.utils import find_model_file, get_sagemaker_function -from djl_python.service_loader import get_annotated_function +from djl_python.utils import find_model_file +from djl_python.custom_formatter_handling import load_custom_code from djl_python.import_utils import xgboost as xgb @@ -28,52 +28,8 @@ class XGBoostHandler: def __init__(self): self.model = None self.initialized = False - self.custom_input_formatter = None - self.custom_output_formatter = None - self.custom_predict_formatter = None - self.custom_model_loading_formatter = None + self.custom_code = None self.init_properties = None - self.is_sagemaker_script = False - - def _load_custom_formatters(self, model_dir: str): - """Load custom formatters, checking DJL decorators first, then SageMaker functions.""" - # Check for DJL decorator-based custom formatters first - self.custom_model_loading_formatter = get_annotated_function( - model_dir, "is_model_loading_formatter") - self.custom_input_formatter = get_annotated_function( - model_dir, "is_input_formatter") - self.custom_output_formatter = get_annotated_function( - 
model_dir, "is_output_formatter") - self.custom_predict_formatter = get_annotated_function( - model_dir, "is_predict_formatter") - - # If no decorator-based formatters found, check for SageMaker-style formatters - if not any([ - self.custom_input_formatter, self.custom_output_formatter, - self.custom_predict_formatter, - self.custom_model_loading_formatter - ]): - - sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') - sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') - sagemaker_predict_fn = get_sagemaker_function( - model_dir, 'predict_fn') - sagemaker_output_fn = get_sagemaker_function( - model_dir, 'output_fn') - - if any([ - sagemaker_model_fn, sagemaker_input_fn, - sagemaker_predict_fn, sagemaker_output_fn - ]): - self.is_sagemaker_script = True - if sagemaker_model_fn: - self.custom_model_loading_formatter = sagemaker_model_fn - if sagemaker_input_fn: - self.custom_input_formatter = sagemaker_input_fn - if sagemaker_predict_fn: - self.custom_predict_formatter = sagemaker_predict_fn - if sagemaker_output_fn: - self.custom_output_formatter = sagemaker_output_fn def initialize(self, properties: dict): # Store initialization properties for use during inference @@ -95,11 +51,11 @@ def initialize(self, properties: dict): f"Unsupported model format: {model_format}. 
Supported formats: json, ubj, pickle, xgb" ) - # Load custom formatters - self._load_custom_formatters(model_dir) + # Load custom code + self.custom_code = load_custom_code(model_dir) - if self.custom_model_loading_formatter: - self.model = self.custom_model_loading_formatter(model_dir) + if self.custom_code.handlers.init_handler: + self.model = self.custom_code.handlers.init_handler(model_dir) else: model_file = find_model_file(model_dir, extensions) if not model_file: @@ -142,7 +98,7 @@ def inference(self, inputs: Input) -> Output: accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) - if not self.custom_output_formatter: + if not self.custom_code.formatters.output_formatter: supported_accept_types = ["application/json", "text/csv"] if not any(supported_type in accept for supported_type in supported_accept_types): @@ -152,12 +108,12 @@ def inference(self, inputs: Input) -> Output: # Input processing X = None - if self.custom_input_formatter: - if self.is_sagemaker_script: - X = self.custom_input_formatter(inputs.get_as_bytes(), - content_type) + if self.custom_code.formatters.input_formatter: + if self.custom_code.is_sagemaker_script: + X = self.custom_code.formatters.input_formatter( + inputs.get_as_bytes(), content_type) else: - X = self.custom_input_formatter(inputs) + X = self.custom_code.formatters.input_formatter(inputs) elif "text/csv" in content_type: X = decode(inputs, content_type, require_csv_headers=False) else: @@ -172,20 +128,23 @@ def inference(self, inputs: Input) -> Output: if X.ndim == 1: X = X.reshape(1, -1) - if self.custom_predict_formatter: - predictions = self.custom_predict_formatter(X, self.model) + if self.custom_code.handlers.prediction_handler: + predictions = self.custom_code.handlers.prediction_handler( + X, self.model) else: dmatrix = xgb.DMatrix(X) predictions = self.model.predict(dmatrix) # Output processing outputs = Output() - if self.custom_output_formatter: - if 
self.is_sagemaker_script: - data = self.custom_output_formatter(predictions, accept) + if self.custom_code.formatters.output_formatter: + if self.custom_code.is_sagemaker_script: + data = self.custom_code.formatters.output_formatter( + predictions, accept) outputs.add_property("Content-Type", accept) else: - data = self.custom_output_formatter(predictions) + data = self.custom_code.formatters.output_formatter( + predictions) outputs.add(data) elif "text/csv" in accept: diff --git a/tests/integration/download_models.sh b/tests/integration/download_models.sh index f7d75dffd..b8804bd39 100755 --- a/tests/integration/download_models.sh +++ b/tests/integration/download_models.sh @@ -35,11 +35,11 @@ python_skl_models_urls=( "https://resources.djl.ai/test-models/python/sklearn/sklearn_unsafe_model_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_skops_model_env_v2.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_sm.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_invalid.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_mixed_djl_sagemaker.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_all_formatters_v3.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_sm_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_invalid_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_mixed_djl_sagemaker_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_all_formatters_v4.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_input_output_v3.zip" 
"https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_invalid_input_v3.zip" "https://resources.djl.ai/test-models/python/sklearn/slow_loading_model.zip" @@ -55,8 +55,8 @@ python_xgb_models_urls=( "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_all.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output_invalid.zip" - "https://resources.djl.ai/test-models/python/xgboost/xgboost_mixed_djl_sagemaker.zip" - "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_all_formatters_v3.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_mixed_djl_sagemaker_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_all_formatters.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_input_output_v3.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_invalid_input_v3.zip" ) diff --git a/tests/integration/run_all_ml_sm_endpoint_tests.sh b/tests/integration/run_all_ml_sm_endpoint_tests.sh new file mode 100755 index 000000000..9b276d47d --- /dev/null +++ b/tests/integration/run_all_ml_sm_endpoint_tests.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Test all model and format permutations +IMAGE_TYPE="candidate" +OUTPUT_FILE="sagemaker_test_results.txt" + +# Clear previous results +> "$OUTPUT_FILE" + +echo "Starting SageMaker ML endpoint tests at $(date)" | tee -a "$OUTPUT_FILE" +echo "Image type: $IMAGE_TYPE" | tee -a "$OUTPUT_FILE" +echo "========================================" | tee -a "$OUTPUT_FILE" + +# All models to test +MODELS=( + "sklearn-sagemaker-formatters" + "sklearn-djl-formatters" + "sklearn-skops-basic" + "xgboost-sagemaker-formatters" + "xgboost-djl-formatters" + "xgboost-basic" +) + +# Multi-model tests +MULTI_MODELS=( + "sklearn-multi" +) + +for model in "${MODELS[@]}"; do + echo "" | tee -a "$OUTPUT_FILE" + echo "Testing: $model with JSON and 
CSV" | tee -a "$OUTPUT_FILE" + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" + + if python3 sagemaker-ml-endpoint-tests.py "$model" "$IMAGE_TYPE" --test-both 2>&1 | tee -a "$OUTPUT_FILE"; then + if grep -q "Successfully tested" "$OUTPUT_FILE"; then + echo "SUCCESS: $model (JSON + CSV)" | tee -a "$OUTPUT_FILE" + else + echo "FAILED: $model (JSON + CSV) - No success message found" | tee -a "$OUTPUT_FILE" + fi + else + echo "FAILED: $model (JSON + CSV) - Script returned error" | tee -a "$OUTPUT_FILE" + fi + + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" +done + +# Test batch predictions +for model in "${MODELS[@]}"; do + echo "" | tee -a "$OUTPUT_FILE" + echo "Testing: $model with batch predictions" | tee -a "$OUTPUT_FILE" + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" + + if python3 sagemaker-ml-endpoint-tests.py "$model" "$IMAGE_TYPE" --test-batch --test-json 2>&1 | tee -a "$OUTPUT_FILE"; then + if grep -q "Successfully tested" "$OUTPUT_FILE"; then + echo "SUCCESS: $model (batch)" | tee -a "$OUTPUT_FILE" + else + echo "FAILED: $model (batch) - No success message found" | tee -a "$OUTPUT_FILE" + fi + else + echo "FAILED: $model (batch) - Script returned error" | tee -a "$OUTPUT_FILE" + fi + + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" +done + +# Test multi-model endpoints +for model in "${MULTI_MODELS[@]}"; do + echo "" | tee -a "$OUTPUT_FILE" + echo "Testing: $model multi-model endpoint" | tee -a "$OUTPUT_FILE" + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" + + if python3 sagemaker-ml-endpoint-tests.py "$model" "$IMAGE_TYPE" --test-multi-model --test-json 2>&1 | tee -a "$OUTPUT_FILE"; then + if grep -q "Successfully tested" "$OUTPUT_FILE"; then + echo "SUCCESS: $model (multi-model)" | tee -a "$OUTPUT_FILE" + else + echo "FAILED: $model (multi-model) - No success message found" | tee -a "$OUTPUT_FILE" + fi + else + echo "FAILED: 
$model (multi-model) - Script returned error" | tee -a "$OUTPUT_FILE" + fi + + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" +done + +echo "" | tee -a "$OUTPUT_FILE" +echo "All tests completed at $(date)" | tee -a "$OUTPUT_FILE" +echo "Results saved to: $OUTPUT_FILE" \ No newline at end of file diff --git a/tests/integration/sagemaker-ml-endpoint-tests.py b/tests/integration/sagemaker-ml-endpoint-tests.py new file mode 100644 index 000000000..3fa975c5b --- /dev/null +++ b/tests/integration/sagemaker-ml-endpoint-tests.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 + +import sagemaker +import boto3 +import json +from sagemaker import Model, Predictor +from sagemaker.utils import unique_name_from_base +from sagemaker.serializers import JSONSerializer, CSVSerializer +from sagemaker.deserializers import JSONDeserializer, CSVDeserializer +from sagemaker.multidatamodel import MultiDataModel +from argparse import ArgumentParser + +ROLE = "arn:aws:iam::185921645874:role/AmazonSageMaker-ExeuctionRole-IntegrationTests" +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_BUCKET = "sm-integration-tests-rubikon-usw2" + +# DJL Serving CPU images +CANDIDATE_IMAGES = { + "cpu-full": + "125045733377.dkr.ecr.us-west-2.amazonaws.com/djl-serving-cpu-full-test:latest" +} + +# Test configurations using S3 URIs +SKLEARN_CONFIGS = { + "sklearn-sagemaker-formatters": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_custom_model_sm_v2.tar", + "payload": { + "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "features": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + 
"env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }, + "sklearn-djl-formatters": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_djl_all_formatters_v4.tar", + "payload": { + "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "features": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + "env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }, + "sklearn-skops-basic": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_skops_model.tar", + "payload": { + "inputs": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + "env_vars": { + "OPTION_SKOPS_TRUSTED_TYPES": + "sklearn.ensemble._forest.RandomForestClassifier", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + } +} + +XGBOOST_CONFIGS = { + "xgboost-sagemaker-formatters": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/xgboost_sagemaker_all.tar", + "payload": { + "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "features": [[1.0, 2.0, 3.0, 4.0, 
5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + "env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }, + "xgboost-djl-formatters": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/xgboost_djl_all_formatters.tar", + "payload": { + "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "features": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + "env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }, + "xgboost-basic": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/xgboost_model.tar", + "payload": { + "inputs": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "batch_payload": { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0", + "csv_batch_payload": + "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0\n3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0", + "env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + } +} + +# Multi-model 
endpoint configuration +MULTI_MODEL_CONFIGS = { + "sklearn-multi": { + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_multi_model_v2/", + "models": { + "model1": { + "payload": { + "inputs": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + }, + "model2": { + "payload": { + "inputs": + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0] + }, + "csv_payload": "2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0" + } + }, + "env_vars": { + "OPTION_SKOPS_TRUSTED_TYPES": + "sklearn.ensemble._forest.RandomForestClassifier", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + } +} + + +def parse_args(): + parser = ArgumentParser( + description="Deploy sklearn/xgboost models to SageMaker endpoints") + parser.add_argument("model_name", help="Model configuration to use") + parser.add_argument( + "image_type", + help="Image type (nightly, candidate, or full ECR URI)", + nargs='?', + default="nightly") + parser.add_argument("-f", + "--framework", + help="Framework to test", + choices=["sklearn", "xgboost"]) + + format_group = parser.add_mutually_exclusive_group() + format_group.add_argument("--test-json", + help="Test JSON input/output only", + action="store_true") + format_group.add_argument("--test-csv", + help="Test CSV input/output only", + action="store_true") + format_group.add_argument("--test-both", + help="Test both JSON and CSV", + action="store_true") + + parser.add_argument("--test-batch", + help="Test batch predictions", + action="store_true") + parser.add_argument("--test-multi-model", + help="Test multi-model endpoint", + action="store_true") + + return parser.parse_args() + + +def get_image_uri(image_type): + if image_type == 'nightly': + return NIGHTLY_IMAGES["cpu-full"] + elif image_type == 'candidate': + return CANDIDATE_IMAGES["cpu-full"] + elif '.dkr.ecr.' 
in image_type or image_type.startswith( + 'deepjavalibrary/'): + return image_type + else: + raise ValueError( + f"Unknown image type: {image_type}. Use 'nightly', 'candidate', or full ECR URI" + ) + + +def get_sagemaker_session(default_bucket=DEFAULT_BUCKET, + default_bucket_prefix=None): + return sagemaker.session.Session( + boto3.session.Session(region_name='us-west-2'), + default_bucket=default_bucket, + default_bucket_prefix=default_bucket_prefix) + + +def get_name_for_resource(name): + cleaned_name = ''.join(filter(str.isalnum, name)) + base_name = f"sm-ml-test-{cleaned_name}" + return unique_name_from_base(base_name) + + +def test_endpoint(framework, + model_name, + image_type, + test_format="json", + test_batch=False): + """Test sklearn/xgboost model on SageMaker endpoint""" + configs = SKLEARN_CONFIGS if framework == "sklearn" else XGBOOST_CONFIGS + config = configs[model_name] + + session = get_sagemaker_session( + default_bucket_prefix=get_name_for_resource( + f"{framework}-{model_name}")) + + predictor = None + model = None + + try: + # Add SageMaker environment variables if specified + env_vars = config.get("env_vars", {}) + + model = Model(image_uri=get_image_uri(image_type), + model_data=config["model_data"], + role=ROLE, + name=get_name_for_resource(f"{framework}-{model_name}"), + env=env_vars, + sagemaker_session=session) + + print(f"Deploying {framework} endpoint for {model_name}...") + if env_vars: + print(f"Using environment variables: {env_vars}") + + endpoint_name = get_name_for_resource( + f"{framework}-{model_name}-endpoint") + model.deploy( + initial_instance_count=1, + instance_type=DEFAULT_INSTANCE_TYPE, + endpoint_name=endpoint_name, + ) + + predictor = Predictor(endpoint_name=endpoint_name, + sagemaker_session=session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer()) + + if test_format in ["json", "both"]: + print("Testing JSON prediction...") + result = predictor.predict(config["payload"]) + print(f"JSON Result: 
{result}") + + if test_batch and "batch_payload" in config: + print("Testing JSON batch prediction...") + batch_result = predictor.predict(config["batch_payload"]) + print(f"JSON Batch Result: {batch_result}") + + if test_format in ["csv", "both"] and "csv_payload" in config: + print("Testing CSV prediction...") + predictor.serializer = CSVSerializer() + predictor.deserializer = CSVDeserializer() + csv_result = predictor.predict(config["csv_payload"]) + print(f"CSV Result: {csv_result}") + + if test_batch and "csv_batch_payload" in config: + print("Testing CSV batch prediction...") + csv_batch_result = predictor.predict( + config["csv_batch_payload"]) + print(f"CSV Batch Result: {csv_batch_result}") + + batch_msg = " with batch" if test_batch else "" + print( + f"✓ Successfully tested {framework} model: {model_name}{batch_msg}" + ) + + except Exception as e: + print(f"✗ Error testing {framework} model {model_name}: {e}") + raise e + finally: + if predictor: + predictor.delete_endpoint() + if model: + model.delete_model() + + +def test_multi_model_endpoint(model_name, image_type, test_format="json"): + """Test multi-model endpoint""" + config = MULTI_MODEL_CONFIGS[model_name] + + session = get_sagemaker_session( + default_bucket_prefix=get_name_for_resource(f"multi-{model_name}")) + + predictor = None + model = None + + try: + env_vars = config.get("env_vars", {}) + + # Use MultiDataModel for multi-model endpoints + model_s3_folder = config["model_data"].replace(".tar", "/") + model = MultiDataModel( + name=get_name_for_resource(f"multi-{model_name}"), + model_data_prefix=model_s3_folder, + image_uri=get_image_uri(image_type), + role=ROLE, + env=env_vars, + sagemaker_session=session) + + print(f"Deploying multi-model endpoint for {model_name}...") + if env_vars: + print(f"Using environment variables: {env_vars}") + + endpoint_name = get_name_for_resource(f"multi-{model_name}-endpoint") + model.deploy( + initial_instance_count=1, + 
instance_type=DEFAULT_INSTANCE_TYPE, + endpoint_name=endpoint_name, + ) + + predictor = Predictor(endpoint_name=endpoint_name, + sagemaker_session=session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer()) + + # Test each model in the multi-model endpoint + for model_id, model_config in config["models"].items(): + print(f"Testing model: {model_id}") + + if test_format in ["json", "both"]: + print(f"Testing JSON prediction for {model_id}...") + # Add model target header for multi-model endpoint + result = predictor.predict( + model_config["payload"], + initial_args={"TargetModel": f"{model_id}.tar"}) + print(f"JSON Result for {model_id}: {result}") + + if test_format in ["csv", "both" + ] and "csv_payload" in model_config: + print(f"Testing CSV prediction for {model_id}...") + predictor.serializer = CSVSerializer() + predictor.deserializer = CSVDeserializer() + csv_result = predictor.predict( + model_config["csv_payload"], + initial_args={"TargetModel": f"{model_id}.tar"}) + print(f"CSV Result for {model_id}: {csv_result}") + # Reset to JSON for next model + predictor.serializer = JSONSerializer() + predictor.deserializer = JSONDeserializer() + + print(f"✓ Successfully tested multi-model endpoint: {model_name}") + + except Exception as e: + print(f"✗ Error testing multi-model endpoint {model_name}: {e}") + raise e + finally: + if predictor: + predictor.delete_endpoint() + if model: + model.delete_model() + + +if __name__ == "__main__": + args = parse_args() + + # Determine test format + if args.test_csv: + test_format = "csv" + elif args.test_both: + test_format = "both" + else: + test_format = "json" # Default + + # Handle multi-model endpoint testing + if args.test_multi_model: + if args.model_name not in MULTI_MODEL_CONFIGS: + raise ValueError( + f"Unknown multi-model config: {args.model_name}. 
Available: {list(MULTI_MODEL_CONFIGS.keys())}" + ) + test_multi_model_endpoint(args.model_name, args.image_type, + test_format) + else: + # Auto-detect framework if not specified + if args.framework: + configs = SKLEARN_CONFIGS if args.framework == "sklearn" else XGBOOST_CONFIGS + if args.model_name not in configs: + raise ValueError( + f"Unknown {args.framework} model: {args.model_name}. Available: {list(configs.keys())}" + ) + test_endpoint(args.framework, args.model_name, args.image_type, + test_format, args.test_batch) + else: + if args.model_name in SKLEARN_CONFIGS: + test_endpoint("sklearn", args.model_name, args.image_type, + test_format, args.test_batch) + elif args.model_name in XGBOOST_CONFIGS: + test_endpoint("xgboost", args.model_name, args.image_type, + test_format, args.test_batch) + else: + raise ValueError( + f"Unknown model: {args.model_name}. Available sklearn: {list(SKLEARN_CONFIGS.keys())}, xgboost: {list(XGBOOST_CONFIGS.keys())}, multi-model: {list(MULTI_MODEL_CONFIGS.keys())}" + ) diff --git a/tests/integration/test_custom_formatters.py b/tests/integration/test_custom_formatters.py index fa1a022e7..027e7d1bf 100644 --- a/tests/integration/test_custom_formatters.py +++ b/tests/integration/test_custom_formatters.py @@ -3,6 +3,7 @@ import tempfile import shutil import zipfile +import pytest from tests import Runner @@ -15,7 +16,7 @@ def test_sklearn_all_custom_formatters(self): download=True) as r: r.launch( cmd= - "serve -m sklearn_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + "serve -m sklearn_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm_v2.zip" ) # Test custom formatters @@ -46,7 +47,7 @@ def test_sagemaker_env_with_custom_formatters(self): r.launch( env_vars=env, cmd= - "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm_v2.zip" ) # Test with custom formatters - use features format from existing model 
@@ -73,7 +74,7 @@ def test_sagemaker_csv_default_with_json_only_formatter(self): r.launch( env_vars=env, cmd= - "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm.zip" + "serve -m sagemaker_custom::Python=file:/opt/ml/model/sklearn_custom_model_sm_v2.zip" ) # Test should fail because output_fn only supports application/json @@ -93,7 +94,7 @@ def test_sagemaker_input_output_formatters(self): with Runner('cpu-full', 'sagemaker_input_output', download=True) as r: r.launch( cmd= - "serve -m sklearn_io::Python=file:/opt/ml/model/sklearn_custom_model_input_output.zip" + "serve -m sklearn_io::Python=file:/opt/ml/model/sklearn_custom_model_input_output_v2.zip" ) # Test with SageMaker input/output formatters @@ -124,7 +125,7 @@ def test_sagemaker_invalid_input_formatter(self): with Runner('cpu-full', 'sagemaker_invalid_input', download=True) as r: r.launch( cmd= - "serve -m sklearn_invalid::Python=file:/opt/ml/model/sklearn_custom_model_input_output_invalid.zip" + "serve -m sklearn_invalid::Python=file:/opt/ml/model/sklearn_custom_model_input_output_invalid_v2.zip" ) # Test should fail because input_fn returns raw list instead of numpy array @@ -274,7 +275,7 @@ def test_sklearn_mixed_djl_sagemaker_formatters(self): download=True) as r: r.launch( cmd= - "serve -m sklearn_mixed::Python=file:/opt/ml/model/sklearn_mixed_djl_sagemaker.zip" + "serve -m sklearn_mixed::Python=file:/opt/ml/model/sklearn_mixed_djl_sagemaker_v2.zip" ) # When DJL decorators are present, SageMaker functions should be completely ignored @@ -302,7 +303,7 @@ def test_xgboost_mixed_djl_sagemaker_formatters(self): download=True) as r: r.launch( cmd= - "serve -m xgboost_mixed::Python=file:/opt/ml/model/xgboost_mixed_djl_sagemaker.zip" + "serve -m xgboost_mixed::Python=file:/opt/ml/model/xgboost_mixed_djl_sagemaker_v2.zip" ) # When DJL decorators are present, SageMaker functions should be completely ignored @@ -329,7 +330,7 @@ def test_sklearn_djl_all_formatters(self): with 
Runner('cpu-full', 'sklearn_djl_all', download=True) as r: r.launch( cmd= - "serve -m sklearn_djl_all::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v3.zip" + "serve -m sklearn_djl_all::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v4.zip" ) # Test DJL decorators @@ -361,7 +362,7 @@ def test_sklearn_djl_env_with_formatters(self): r.launch( env_vars=env, cmd= - "serve -m sklearn_djl_env::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v3.zip" + "serve -m sklearn_djl_env::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v4.zip" ) test_data = { @@ -438,7 +439,7 @@ def test_xgboost_djl_all_formatters(self): with Runner('cpu-full', 'xgboost_djl_all', download=True) as r: r.launch( cmd= - "serve -m xgboost_djl_all::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v3.zip" + "serve -m xgboost_djl_all::Python=file:/opt/ml/model/xgboost_djl_all_formatters.zip" ) # Test DJL decorators @@ -470,7 +471,7 @@ def test_xgboost_djl_env_with_formatters(self): r.launch( env_vars=env, cmd= - "serve -m xgboost_djl_env::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v3.zip" + "serve -m xgboost_djl_env::Python=file:/opt/ml/model/xgboost_djl_all_formatters.zip" ) test_data = { diff --git a/tests/integration/test_sagemaker_compatibility.py b/tests/integration/test_sagemaker_compatibility.py index cf066d179..6335ed9e2 100644 --- a/tests/integration/test_sagemaker_compatibility.py +++ b/tests/integration/test_sagemaker_compatibility.py @@ -1,5 +1,6 @@ import os import requests +import pytest from tests import Runner From 2acd502df4064d5670278f2e76716e788ffea75f Mon Sep 17 00:00:00 2001 From: smouaa Date: Thu, 30 Oct 2025 19:14:54 +0000 Subject: [PATCH 3/3] Add sklearn and xgboost documentation and updated handlers to inherit custom formatter handler --- .../djl_python/custom_formatter_handling.py | 120 ++++---- .../setup/djl_python/sklearn_handler.py | 63 ++-- engines/python/setup/djl_python/utils.py | 22 +- .../setup/djl_python/xgboost_handler.py | 66 ++--- 
serving/docs/sklearn_handler.md | 268 ++++++++++++++++++ serving/docs/xgboost_handler.md | 255 +++++++++++++++++ .../ai/djl/serving/util/ConfigManager.java | 6 +- tests/integration/download_models.sh | 12 +- .../run_all_ml_sm_endpoint_tests.sh | 24 ++ .../sagemaker-ml-endpoint-tests.py | 179 +++++++++++- tests/integration/test_custom_formatters.py | 16 +- .../java/ai/djl/serving/wlm/ModelInfo.java | 6 +- 12 files changed, 868 insertions(+), 169 deletions(-) create mode 100644 serving/docs/sklearn_handler.md create mode 100644 serving/docs/xgboost_handler.md diff --git a/engines/python/setup/djl_python/custom_formatter_handling.py b/engines/python/setup/djl_python/custom_formatter_handling.py index 9e41f6b07..46304e380 100644 --- a/engines/python/setup/djl_python/custom_formatter_handling.py +++ b/engines/python/setup/djl_python/custom_formatter_handling.py @@ -17,6 +17,7 @@ from djl_python.service_loader import get_annotated_function from djl_python.utils import get_sagemaker_function +from djl_python.inputs import Input logger = logging.getLogger(__name__) @@ -30,62 +31,37 @@ def __init__(self, message: str, original_exception: Exception): self.__cause__ = original_exception -@dataclass -class CustomFormatters: - """Container for input/output formatting functions""" - input_formatter: Optional[Callable] = None - output_formatter: Optional[Callable] = None - - -@dataclass -class CustomHandlers: - """Container for prediction/initialization handler functions""" - prediction_handler: Optional[Callable] = None - init_handler: Optional[Callable] = None - - -@dataclass -class CustomCode: - """Container for all custom formatters and handlers""" - formatters: CustomFormatters - handlers: CustomHandlers - is_sagemaker_script: bool = False - - def __init__(self): - self.formatters = CustomFormatters() - self.handlers = CustomHandlers() - self.is_sagemaker_script = False - - class CustomFormatterHandler: def __init__(self): - self.custom_code = CustomCode() + 
self.input_formatter: Optional[Callable] = None + self.output_formatter: Optional[Callable] = None + self.prediction_handler: Optional[Callable] = None + self.init_handler: Optional[Callable] = None + self.is_sagemaker_script: bool = False - def load_formatters(self, model_dir: str) -> CustomCode: + def load_formatters(self, model_dir: str): """Load custom formatters/handlers from model.py with SageMaker detection""" try: - self.custom_code.formatters.input_formatter = get_annotated_function( + self.input_formatter = get_annotated_function( model_dir, "is_input_formatter") - self.custom_code.formatters.output_formatter = get_annotated_function( + self.output_formatter = get_annotated_function( model_dir, "is_output_formatter") - self.custom_code.handlers.prediction_handler = get_annotated_function( + self.prediction_handler = get_annotated_function( model_dir, "is_prediction_handler") - self.custom_code.handlers.init_handler = get_annotated_function( - model_dir, "is_init_handler") + self.init_handler = get_annotated_function(model_dir, + "is_init_handler") # Detect SageMaker script pattern for backward compatibility self._detect_sagemaker_functions(model_dir) logger.info( - f"Loaded formatters - input: {bool(self.custom_code.formatters.input_formatter)}, " - f"output: {bool(self.custom_code.formatters.output_formatter)}" - ) + f"Loaded formatters - input: {bool(self.input_formatter)}, " + f"output: {bool(self.output_formatter)}") logger.info( - f"Loaded handlers - prediction: {bool(self.custom_code.handlers.prediction_handler)}, " - f"init: {bool(self.custom_code.handlers.init_handler)}, " - f"sagemaker: {self.custom_code.is_sagemaker_script}") - return self.custom_code + f"Loaded handlers - prediction: {bool(self.prediction_handler)}, " + f"init: {bool(self.init_handler)}, " + f"sagemaker: {self.is_sagemaker_script}") except Exception as e: raise CustomFormatterError( f"Failed to load custom code from {model_dir}", e) @@ -94,10 +70,8 @@ def 
_detect_sagemaker_functions(self, model_dir: str): """Detect and load SageMaker-style functions for backward compatibility""" # If no decorator-based code found, check for SageMaker functions if not any([ - self.custom_code.formatters.input_formatter, - self.custom_code.formatters.output_formatter, - self.custom_code.handlers.prediction_handler, - self.custom_code.handlers.init_handler + self.input_formatter, self.output_formatter, + self.prediction_handler, self.init_handler ]): sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn') sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn') @@ -110,34 +84,44 @@ def _detect_sagemaker_functions(self, model_dir: str): sagemaker_model_fn, sagemaker_input_fn, sagemaker_predict_fn, sagemaker_output_fn ]): - self.custom_code.is_sagemaker_script = True + self.is_sagemaker_script = True if sagemaker_model_fn: - self.custom_code.handlers.init_handler = sagemaker_model_fn + self.init_handler = sagemaker_model_fn if sagemaker_input_fn: - self.custom_code.formatters.input_formatter = sagemaker_input_fn + self.input_formatter = sagemaker_input_fn if sagemaker_predict_fn: - self.custom_code.handlers.prediction_handler = sagemaker_predict_fn + self.prediction_handler = sagemaker_predict_fn if sagemaker_output_fn: - self.custom_code.formatters.output_formatter = sagemaker_output_fn + self.output_formatter = sagemaker_output_fn logger.info("Loaded SageMaker-style functions") def apply_input_formatter(self, decoded_payload, **kwargs): """Apply input formatter if available""" - if self.custom_code.formatters.input_formatter: + if self.input_formatter is not None: try: - return self.custom_code.formatters.input_formatter( - decoded_payload, **kwargs) + if self.is_sagemaker_script: + # SageMaker input_fn expects (data, content_type) + content_type = kwargs.get('content_type') + return self.input_formatter(decoded_payload.get_as_bytes(), + content_type) + else: + return self.input_formatter(decoded_payload, **kwargs) 
except Exception as e: logger.exception("Custom input formatter failed") raise CustomFormatterError( "Custom input formatter execution failed", e) return decoded_payload - def apply_output_formatter(self, output): + def apply_output_formatter(self, output, **kwargs): """Apply output formatter if available""" - if self.custom_code.formatters.output_formatter: + if self.output_formatter is not None: try: - return self.custom_code.formatters.output_formatter(output) + if self.is_sagemaker_script: + # SageMaker output_fn expects (predictions, accept) + accept = kwargs.get('accept') + return self.output_formatter(output, accept) + else: + return self.output_formatter(output) except Exception as e: logger.exception("Custom output formatter failed") raise CustomFormatterError( @@ -156,8 +140,24 @@ async def apply_output_formatter_streaming_raw(self, stream_generator): raise CustomFormatterError( "Custom streaming formatter execution failed", e) + def apply_init_handler(self, model_dir, **kwargs): + """Apply custom init handler if available""" + if self.init_handler is not None: + try: + return self.init_handler(model_dir, **kwargs) + except Exception as e: + logger.exception("Custom init handler failed") + raise CustomFormatterError( + "Custom init handler execution failed", e) + return None -def load_custom_code(model_dir: str) -> CustomCode: - """Load custom code, checking DJL decorators first, then SageMaker functions""" - handler = CustomFormatterHandler() - return handler.load_formatters(model_dir) + def apply_prediction_handler(self, X, model, **kwargs): + """Apply custom prediction handler if available""" + if self.prediction_handler is not None: + try: + return self.prediction_handler(X, model, **kwargs) + except Exception as e: + logger.exception("Custom prediction handler failed") + raise CustomFormatterError( + "Custom prediction handler execution failed", e) + return None diff --git a/engines/python/setup/djl_python/sklearn_handler.py 
b/engines/python/setup/djl_python/sklearn_handler.py index 4e0393dd0..956001c8d 100644 --- a/engines/python/setup/djl_python/sklearn_handler.py +++ b/engines/python/setup/djl_python/sklearn_handler.py @@ -19,16 +19,16 @@ from djl_python import Input, Output from djl_python.encode_decode import decode from djl_python.utils import find_model_file -from djl_python.custom_formatter_handling import load_custom_code +from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError from djl_python.import_utils import joblib, cloudpickle, skops_io as sio -class SklearnHandler: +class SklearnHandler(CustomFormatterHandler): def __init__(self): + super().__init__() self.model = None self.initialized = False - self.custom_code = None self.init_properties = None def _get_trusted_types(self, properties: dict): @@ -64,11 +64,15 @@ def initialize(self, properties: dict): ) # Load custom code - self.custom_code = load_custom_code(model_dir) + try: + self.load_formatters(model_dir) + except CustomFormatterError as e: + raise # Load model - if self.custom_code.handlers.init_handler: - self.model = self.custom_code.handlers.init_handler(model_dir) + init_result = self.apply_init_handler(model_dir) + if init_result is not None: + self.model = init_result else: model_file = find_model_file(model_dir, extensions) if not model_file: @@ -110,29 +114,23 @@ def inference(self, inputs: Input) -> Output: accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) - if not self.custom_code.formatters.output_formatter: - supported_accept_types = ["application/json", "text/csv"] + if self.output_formatter is None: # No formatter available if not any(supported_type in accept - for supported_type in supported_accept_types): + for supported_type in ["application/json", "text/csv"]): raise ValueError( - f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}" + f"Unsupported Accept type: {accept}. 
Supported types: application/json, text/csv" ) # Input processing - X = None - if self.custom_code.formatters.input_formatter: - if self.custom_code.is_sagemaker_script: - X = self.custom_code.formatters.input_formatter( - inputs.get_as_bytes(), content_type) + X = self.apply_input_formatter(inputs, content_type=content_type) + if X is inputs: # No formatter applied + if "text/csv" in content_type: + X = decode(inputs, content_type, require_csv_headers=False) else: - X = self.custom_code.formatters.input_formatter(inputs) - elif "text/csv" in content_type: - X = decode(inputs, content_type, require_csv_headers=False) - else: - input_map = decode(inputs, content_type) - data = input_map.get("inputs") if isinstance(input_map, - dict) else input_map - X = np.array(data) + input_map = decode(inputs, content_type) + data = input_map.get("inputs") if isinstance( + input_map, dict) else input_map + X = np.array(data) if X is None or not hasattr(X, 'ndim'): raise ValueError( @@ -141,23 +139,18 @@ def inference(self, inputs: Input) -> Output: if X.ndim == 1: X = X.reshape(1, -1) - if self.custom_code.handlers.prediction_handler: - predictions = self.custom_code.handlers.prediction_handler( - X, self.model) - else: + predictions = self.apply_prediction_handler(X, self.model) + if predictions is None: predictions = self.model.predict(X) # Output processing outputs = Output() - if self.custom_code.formatters.output_formatter: - if self.custom_code.is_sagemaker_script: - data = self.custom_code.formatters.output_formatter( - predictions, accept) + formatted_output = self.apply_output_formatter(predictions, + accept=accept) + if formatted_output is not predictions: # Formatter was applied + if self.is_sagemaker_script: outputs.add_property("Content-Type", accept) - else: - data = self.custom_code.formatters.output_formatter( - predictions) - outputs.add(data) + outputs.add(formatted_output) elif "text/csv" in accept: csv_buffer = StringIO() np.savetxt(csv_buffer, predictions, 
fmt='%s', delimiter=',') diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index c27764aaa..469136f9f 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -23,10 +23,12 @@ # SageMaker function signatures for validation SAGEMAKER_SIGNATURES = { - 'model_fn': ['model_dir'], - 'input_fn': ['request_body', 'content_type'], - 'predict_fn': ['input_data', 'model'], - 'output_fn': ['prediction', 'accept'] + 'model_fn': [['model_dir']], + 'input_fn': [['request_body', 'content_type'], + ['request_body', 'request_content_type']], + 'predict_fn': [['input_data', 'model']], + 'output_fn': [['prediction', 'accept'], + ['prediction', 'response_content_type']] } @@ -201,9 +203,8 @@ def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]: return all_matches[0] if all_matches else None -def _validate_sagemaker_function( - module, func_name: str, - expected_params: List[str]) -> Optional[Callable]: +def _validate_sagemaker_function(module, func_name: str, + expected_params) -> Optional[Callable]: """ Validate that function exists and has correct signature Returns the function if valid, None otherwise @@ -219,9 +220,10 @@ def _validate_sagemaker_function( sig = inspect.signature(func) param_names = list(sig.parameters.keys()) - # Check parameter count and names match exactly - if param_names == expected_params: - return func + # Handle multiple signature options + for signature_option in expected_params: + if param_names == signature_option: + return func except (ValueError, TypeError): # Handle cases where signature inspection fails pass diff --git a/engines/python/setup/djl_python/xgboost_handler.py b/engines/python/setup/djl_python/xgboost_handler.py index 23244ff06..70a79545e 100644 --- a/engines/python/setup/djl_python/xgboost_handler.py +++ b/engines/python/setup/djl_python/xgboost_handler.py @@ -19,16 +19,16 @@ from djl_python import Input, Output from 
djl_python.encode_decode import decode from djl_python.utils import find_model_file -from djl_python.custom_formatter_handling import load_custom_code +from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError from djl_python.import_utils import xgboost as xgb -class XGBoostHandler: +class XGBoostHandler(CustomFormatterHandler): def __init__(self): + super().__init__() self.model = None self.initialized = False - self.custom_code = None self.init_properties = None def initialize(self, properties: dict): @@ -52,10 +52,14 @@ def initialize(self, properties: dict): ) # Load custom code - self.custom_code = load_custom_code(model_dir) - - if self.custom_code.handlers.init_handler: - self.model = self.custom_code.handlers.init_handler(model_dir) + try: + self.load_formatters(model_dir) + except CustomFormatterError as e: + raise + + init_result = self.apply_init_handler(model_dir) + if init_result is not None: + self.model = init_result else: model_file = find_model_file(model_dir, extensions) if not model_file: @@ -98,29 +102,23 @@ def inference(self, inputs: Input) -> Output: accept = default_accept # Validate accept type (skip validation if custom output formatter is provided) - if not self.custom_code.formatters.output_formatter: - supported_accept_types = ["application/json", "text/csv"] + if self.output_formatter is None: # No formatter available if not any(supported_type in accept - for supported_type in supported_accept_types): + for supported_type in ["application/json", "text/csv"]): raise ValueError( - f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}" + f"Unsupported Accept type: {accept}. 
Supported types: application/json, text/csv" ) # Input processing - X = None - if self.custom_code.formatters.input_formatter: - if self.custom_code.is_sagemaker_script: - X = self.custom_code.formatters.input_formatter( - inputs.get_as_bytes(), content_type) + X = self.apply_input_formatter(inputs, content_type=content_type) + if X is inputs: # No formatter applied + if "text/csv" in content_type: + X = decode(inputs, content_type, require_csv_headers=False) else: - X = self.custom_code.formatters.input_formatter(inputs) - elif "text/csv" in content_type: - X = decode(inputs, content_type, require_csv_headers=False) - else: - input_map = decode(inputs, content_type) - data = input_map.get("inputs") if isinstance(input_map, - dict) else input_map - X = np.array(data) + input_map = decode(inputs, content_type) + data = input_map.get("inputs") if isinstance( + input_map, dict) else input_map + X = np.array(data) if X is None or not hasattr(X, 'ndim'): raise ValueError( @@ -128,25 +126,19 @@ def inference(self, inputs: Input) -> Output: if X.ndim == 1: X = X.reshape(1, -1) - if self.custom_code.handlers.prediction_handler: - predictions = self.custom_code.handlers.prediction_handler( - X, self.model) - else: + predictions = self.apply_prediction_handler(X, self.model) + if predictions is None: dmatrix = xgb.DMatrix(X) predictions = self.model.predict(dmatrix) # Output processing outputs = Output() - if self.custom_code.formatters.output_formatter: - if self.custom_code.is_sagemaker_script: - data = self.custom_code.formatters.output_formatter( - predictions, accept) + formatted_output = self.apply_output_formatter(predictions, + accept=accept) + if formatted_output is not predictions: # Formatter was applied + if self.is_sagemaker_script: outputs.add_property("Content-Type", accept) - else: - data = self.custom_code.formatters.output_formatter( - predictions) - outputs.add(data) - + outputs.add(formatted_output) elif "text/csv" in accept: csv_buffer = StringIO() 
np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',') diff --git a/serving/docs/sklearn_handler.md b/serving/docs/sklearn_handler.md new file mode 100644 index 000000000..c3f6c42eb --- /dev/null +++ b/serving/docs/sklearn_handler.md @@ -0,0 +1,268 @@ +# Scikit-learn Handler + +The scikit-learn handler enables serving scikit-learn models with DJL Serving's Python engine. It supports multiple model formats and provides flexible customization options. + +## Supported Model Formats + +### Secure Formats (Require option.skops_trusted_types='list,of,types,separated,by,commas') +- **skops**: `.skops` files + +### Insecure Formats (Require option.trust_insecure_model_files=true) +- **joblib**: `.joblib`, `.jl` files +- **pickle**: `.pkl`, `.pickle` files +- **cloudpickle**: `.pkl`, `.pickle`, `.cloudpkl` files + +## Configuration + +The scikit-learn handler accepts configurations in two formats: + +* `serving.properties` Configuration File (per model configurations) +* Environment Variables (global configurations) + +For most use-cases, using environment variables is sufficient. +Configurations specified in the `serving.properties` files will override configurations specified in environment variables. 
+ +### serving.properties Configuration + +Create a `serving.properties` file: + +```properties +engine=Python +option.entryPoint=djl_python.sklearn_handler +option.model_format=skops +option.skops_trusted_types=sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray +``` + +For insecure formats, add: +```properties +option.trust_insecure_model_files=true +``` + +### Environment Variable Configuration + +Alternatively, configure via environment variables: + +```python +env = { + 'OPTION_ENGINE': 'Python', + 'OPTION_ENTRY_POINT': 'djl_python.sklearn_handler', + 'OPTION_MODEL_FORMAT': 'skops', + 'OPTION_SKOPS_TRUSTED_TYPES': 'sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray', + 'OPTION_TRUST_INSECURE_MODEL_FILES': 'false' +} +``` + +Configuration keys that start with `option.` can be specified as environment variables using the `OPTION_` prefix. +The configuration `option.` is translated to environment variable `OPTION_`. + +### Model Directory Structure + +``` +model/ +├── serving.properties # Optional: If absent from model directory, must set with ENV variables +├── model.skops # Your scikit-learn model file +└── model.py # Optional: Custom handlers +``` + +### Default Input/Output + +The scikit-learn handler supports both JSON and CSV input/output formats. + +**JSON Input Format:** +```json +{ + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0]] +} +``` + +Also accepts direct arrays format: +```json +[ + [1.0, 2.0, 3.0, 4.0, 5.0], + [2.0, 3.0, 4.0, 5.0, 6.0] +] +``` + +**JSON Output Format:** +```json +{ + "predictions": [1] +} +``` + +**CSV Input Format:** +``` +1.0,2.0,3.0,4.0,5.0 +2.0,3.0,4.0,5.0,6.0 +``` + +**CSV Output Format:** +``` +1 +0 +``` +Each prediction is returned on a separate line. 
+ +## Custom Formatters/Handlers + +There are two ways to customize the scikit-learn handler behavior: + +### Method 1: DJL Decorators + +Create a `model.py` file using DJL decorators: + +```python +from djl_python.input_parser import input_formatter, prediction_handler, init_handler +from djl_python.output_formatter import output_formatter +from djl_python import Input +import numpy as np +import joblib +import os + +@init_handler +def custom_init(model_dir, **kwargs): + """Custom model initialization""" + model_path = os.path.join(model_dir, "model.skops") + import skops.io as sio + trusted_types = ['sklearn.ensemble._forest.RandomForestClassifier', 'numpy.ndarray'] + model = sio.load(model_path, trusted=trusted_types) + return model + +@input_formatter +def custom_input(inputs: Input, **kwargs): + """Custom input processing - returns numpy array for default predict""" + data = inputs.get_as_json() + features = data.get("features", data.get("inputs")) + X = np.array(features) + if X.ndim == 1: + X = X.reshape(1, -1) + return X + +@prediction_handler +def custom_predict(X, model, **kwargs): + """Custom prediction logic""" + predictions = model.predict(X) + if hasattr(model, 'predict_proba'): + probabilities = model.predict_proba(X) + return {"predictions": predictions, "probabilities": probabilities} + return predictions + +@output_formatter +def custom_output(predictions): + """Custom output formatting""" + if isinstance(predictions, dict): + return predictions + return {"predictions": predictions.tolist()} +``` + +### Method 2: SageMaker-Style Functions + +Alternatively, use SageMaker-compatible function signatures: + +```python +import numpy as np +import joblib +import json +import os + +def model_fn(model_dir): + """Load model - model_dir is the directory name""" + model_path = os.path.join(model_dir, "model.skops") + import skops.io as sio + trusted_types = ['sklearn.ensemble._forest.RandomForestClassifier', 'numpy.ndarray'] + model = sio.load(model_path, 
trusted=trusted_types) + return model + +def input_fn(request_body, request_content_type): + """Parse input - request_body is byte buffer (request_content_type can also be named content_type)""" + if request_content_type == 'application/json': + data = json.loads(request_body.decode('utf-8')) + features = data.get("features", data.get("inputs")) + return np.array(features) # Return numpy for default predict + elif request_content_type == 'text/csv': + import io + data = np.loadtxt(io.StringIO(request_body.decode('utf-8')), delimiter=',') + return data + else: + raise ValueError(f"Unsupported content type: {request_content_type}") + +def predict_fn(input_object, model): + """Run prediction""" + return model.predict(input_object) + +def output_fn(prediction, response_content_type): + """Format output - returns byte array (response_content_type can also be named accept)""" + if response_content_type == 'application/json': + result = {"predictions": prediction.tolist()} + return json.dumps(result).encode('utf-8') + elif response_content_type == 'text/csv': + return '\n'.join(map(str, prediction)).encode('utf-8') + else: + raise ValueError(f"Unsupported accept type: {response_content_type}") +``` + +### Important Notes + +- **DJL decorators take precedence** over SageMaker functions - cannot mix both approaches +- **Omitted functions use defaults**: If any decorator/function is omitted, the handler uses default logic +- **Default predict expects numpy**: When using custom input formatters, return numpy arrays for default prediction +- **Default output processes numpy**: Default output formatting expects numpy arrays from prediction + +### SageMaker Model Example +```python +from sagemaker import Model + +# Environment variables for scikit-learn handler +env = { + 'SAGEMAKER_MODEL_SERVER_VMARGS': '-Xmx2g -Xms2g', + 'SAGEMAKER_STARTUP_TIMEOUT': '600', + 'SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS': '240', + 'SAGEMAKER_MAX_PAYLOAD_IN_MB': '10', + 'SAGEMAKER_NUM_MODEL_WORKERS': 
'1' +} + +# Create SageMaker model +model = Model( + image_uri='763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.35.0-cpu-full', + model_data='s3://your-bucket/sklearn-model.tar.gz', + role=role, + env=env +) + +# Deploy endpoint +predictor = model.deploy( + initial_instance_count=1, + instance_type='ml.m5.xlarge' +) +``` + +### Available Environment Variables + +#### Handler Configuration +| Variable | Description | Default | Example | +|----------|-------------|---------|----------| +| `OPTION_ENTRY_POINT` | Handler entry point | - | `djl_python.sklearn_handler` | +| `OPTION_MODEL_FORMAT` | Model file format | - | `skops`, `joblib`, `pickle`, `cloudpickle` | +| `OPTION_SKOPS_TRUSTED_TYPES` | Trusted types for skops | - | `sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray` | +| `OPTION_TRUST_INSECURE_MODEL_FILES` | Allow insecure formats | `false` | `true` | + +#### SageMaker-Specific Variables (Backwards Compatible) +These variables work with any DJL deployment - not just SageMaker endpoints. +| Variable | Description | Example | +|----------|-------------|----------| +| `SAGEMAKER_MAX_REQUEST_SIZE` | Max request size (bytes) | `10485760` | +| `SAGEMAKER_NUM_MODEL_WORKERS` | Number of model workers | `2` | +| `SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT` | Default accept header | `application/json` | +| `SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS` | Model server prediction timeout | `240` | +| `SAGEMAKER_MODEL_SERVER_VMARGS` | JVM arguments | `-Xmx4g -Xms4g` | +| `SAGEMAKER_STARTUP_TIMEOUT` | Startup timeout (seconds) | `600` | +| `SAGEMAKER_MAX_PAYLOAD_IN_MB` | Max payload size (MB) | `10` | + +## Troubleshooting + +### Common Issues + +1. **Model format not recognized**: Ensure file extension matches format +2. **Security error**: Set appropriate trusted types for skops or enable trust_insecure_model_files +3. **Input shape mismatch**: Verify input dimensions match training data or if using custom formatters, check formatter logic +4. 
**Import errors**: Ensure all required scikit-learn modules are available diff --git a/serving/docs/xgboost_handler.md b/serving/docs/xgboost_handler.md new file mode 100644 index 000000000..9840db4e5 --- /dev/null +++ b/serving/docs/xgboost_handler.md @@ -0,0 +1,255 @@ +# XGBoost Handler + +The XGBoost handler enables serving XGBoost models with DJL Serving's Python engine. It supports multiple model formats and provides flexible customization options. + +## Supported Model Formats + +### Secure Formats +- **json**: `.json` files +- **ubj**: `.ubj` files + +### Insecure Formats (Require option.trust_insecure_model_files=true) +- **pickle**: `.pkl`, `.pickle` files +- **xgb**: `.xgb`, `.model`, `.bst` files + +## Configuration + +The XGBoost handler accepts configurations in two formats: + +* `serving.properties` Configuration File (per model configurations) +* Environment Variables (global configurations) + +For most use-cases, using environment variables is sufficient. +Configurations specified in the `serving.properties` files will override configurations specified in environment variables. + +### serving.properties Configuration + +Create a `serving.properties` file: + +```properties +engine=Python +option.entryPoint=djl_python.xgboost_handler +option.model_format=json +``` + +For insecure formats, add: +```properties +option.trust_insecure_model_files=true +``` + +### Environment Variable Configuration + +Alternatively, configure via environment variables: + +```python +env = { + 'OPTION_ENGINE': 'Python', + 'OPTION_ENTRY_POINT': 'djl_python.xgboost_handler', + 'OPTION_MODEL_FORMAT': 'json', + 'OPTION_TRUST_INSECURE_MODEL_FILES': 'false' +} +``` + +Configuration keys that start with `option.` can be specified as environment variables using the `OPTION_` prefix. +The configuration `option.` is translated to environment variable `OPTION_`. 
+
+### Model Directory Structure
+
+```
+model/
+├── serving.properties # Optional: If absent from model directory, must set with ENV variables
+├── model.json # Your XGBoost model file
+└── model.py # Optional: Custom handlers
+```
+
+### Default Input/Output
+
+The XGBoost handler supports both JSON and CSV input/output formats.
+
+**JSON Input Format:**
+```json
+{
+    "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0]]
+}
+```
+
+Also accepts direct arrays format:
+```json
+[
+    [1.0, 2.0, 3.0, 4.0, 5.0],
+    [2.0, 3.0, 4.0, 5.0, 6.0]
+]
+```
+
+**JSON Output Format:**
+```json
+{
+    "predictions": [0.8234]
+}
+```
+
+**CSV Input Format:**
+```
+1.0,2.0,3.0,4.0,5.0
+2.0,3.0,4.0,5.0,6.0
+```
+
+**CSV Output Format:**
+```
+0.8234
+0.7456
+```
+Each prediction is returned on a separate line.
+
+## Custom Formatters/Handlers
+
+There are two ways to customize the XGBoost handler behavior:
+
+### Method 1: DJL Decorators
+
+Create a `model.py` file using DJL decorators:
+
+```python
+from djl_python.input_parser import input_formatter, predict_formatter, model_loading_formatter
+from djl_python.output_formatter import output_formatter
+from djl_python import Input
+import numpy as np
+import xgboost as xgb
+import os
+
+@model_loading_formatter
+def custom_init(model_dir, **kwargs):
+    """Custom model initialization"""
+    model_path = os.path.join(model_dir, "model.json")
+    model = xgb.Booster()
+    model.load_model(model_path)
+    return model
+
+@input_formatter
+def custom_input(inputs: Input, **kwargs):
+    """Custom input processing - returns numpy array for default predict"""
+    data = inputs.get_as_json()
+    features = data.get("features", data.get("inputs"))
+    return np.array(features) # Default predict expects numpy
+
+@predict_formatter
+def custom_predict(X, model, **kwargs):
+    """Custom prediction logic"""
+    dmatrix = xgb.DMatrix(X)
+    return model.predict(dmatrix)
+
+@output_formatter
+def custom_output(predictions):
+    """Custom output formatting"""
+    return {"predictions": predictions.tolist()}
+```
+
+### Method 2: SageMaker-Style Functions + +Alternatively, use SageMaker-compatible function signatures in `model.py`: + +```python +import numpy as np +import xgboost as xgb +import json +import os + +def model_fn(model_dir): + """Load model - model_dir is the directory name""" + model_path = os.path.join(model_dir, "model.json") + model = xgb.Booster() + model.load_model(model_path) + return model + +def input_fn(request_body, request_content_type): + """Parse input - request_body is byte buffer (request_content_type can also be named content_type)""" + if request_content_type == 'application/json': + data = json.loads(request_body.decode('utf-8')) + features = data.get("features", data.get("inputs")) + return np.array(features) # Return numpy for default predict + elif request_content_type == 'text/csv': + import io + data = np.loadtxt(io.StringIO(request_body.decode('utf-8')), delimiter=',') + return data + else: + raise ValueError(f"Unsupported content type: {request_content_type}") + +def predict_fn(input_object, model): + """Run prediction - input_object from input_fn""" + dmatrix = xgb.DMatrix(input_object) + return model.predict(dmatrix) + +def output_fn(prediction, response_content_type): + """Format output - returns byte array (response_content_type can also be named accept)""" + if response_content_type == 'application/json': + result = {"predictions": prediction.tolist()} + return json.dumps(result).encode('utf-8') + elif response_content_type == 'text/csv': + return '\n'.join(map(str, prediction)).encode('utf-8') + else: + raise ValueError(f"Unsupported accept type: {response_content_type}") +``` + +### Important Notes + +- **DJL decorators take precedence** over SageMaker functions - cannot mix both approaches +- **Omitted functions use defaults**: If any decorator/function is omitted, the handler uses default logic +- **Default predict expects numpy**: When using custom input formatters, return numpy arrays for default prediction +- **Default output 
processes numpy**: Default output formatting expects numpy arrays from prediction + +### SageMaker Model Example +```python +from sagemaker import Model + +# Environment variables for XGBoost handler +env = { + 'SAGEMAKER_MODEL_SERVER_VMARGS': '-Xmx2g -Xms2g', + 'SAGEMAKER_STARTUP_TIMEOUT': '600', + 'SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS': '240', + 'SAGEMAKER_MAX_PAYLOAD_IN_MB': '10', + 'SAGEMAKER_NUM_MODEL_WORKERS': '1' +} + +# Create SageMaker model +model = Model( + image_uri='763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.35.0-cpu-full', + model_data='s3://your-bucket/xgboost-model.tar.gz', + role=role, + env=env +) + +# Deploy endpoint +predictor = model.deploy( + initial_instance_count=1, + instance_type='ml.m5.xlarge' +) +``` + +### Available Environment Variables + +#### Handler Configuration +| Variable | Description | Default | Example | +|----------|-------------|---------|----------| +| `OPTION_ENTRY_POINT` | Handler entry point | - | `djl_python.xgboost_handler` | +| `OPTION_MODEL_FORMAT` | Model file format | - | `json`, `ubj`, `pickle`, `xgb` | +| `OPTION_TRUST_INSECURE_MODEL_FILES` | Allow insecure formats | `false` | `true` | + +#### SageMaker-Specific Variables (Backwards Compatible) +These variables work with any DJL deployment - not just SageMaker endpoints. +| Variable | Description | Example | +|----------|-------------|----------| +| `SAGEMAKER_MAX_REQUEST_SIZE` | Max request size (bytes) | `10485760` | +| `SAGEMAKER_NUM_MODEL_WORKERS` | Number of model workers | `2` | +| `SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT` | Default accept header | `application/json` | +| `SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS` | Model server prediction timeout | `240` | +| `SAGEMAKER_MODEL_SERVER_VMARGS` | JVM arguments | `-Xmx4g -Xms4g` | +| `SAGEMAKER_STARTUP_TIMEOUT` | Startup timeout (seconds) | `600` | +| `SAGEMAKER_MAX_PAYLOAD_IN_MB` | Max payload size (MB) | `10` | + +## Troubleshooting + +### Common Issues + +1. 
**Model format not recognized**: Ensure file extension matches format +2. **Security error**: Set `trust_insecure_model_files=true` for pickle/xgb formats +3. **Input shape mismatch**: Verify input dimensions match training data or if using custom formatters, check formatter logic \ No newline at end of file diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index ec30ad371..235926c93 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -128,8 +128,6 @@ private ConfigManager(Arguments args) { if (models != null) { prop.setProperty(LOAD_MODELS, String.join(",", models)); } - // Apply SageMaker compatibility for server-level configurations - SageMakerCompatibility.applyServerCompatibility(prop); Map env = Utils.getenv(); @@ -139,6 +137,10 @@ private ConfigManager(Arguments args) { prop.put(key.substring(8).toLowerCase(Locale.ROOT), entry.getValue()); } } + + // Apply SageMaker compatibility for server-level configurations + SageMakerCompatibility.applyServerCompatibility(prop); + for (Map.Entry entry : prop.entrySet()) { String key = (String) entry.getKey(); if (key.startsWith("error_rate_")) { diff --git a/tests/integration/download_models.sh b/tests/integration/download_models.sh index b8804bd39..6d5e2455a 100755 --- a/tests/integration/download_models.sh +++ b/tests/integration/download_models.sh @@ -39,9 +39,9 @@ python_skl_models_urls=( "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_input_output_invalid_v2.zip" "https://resources.djl.ai/test-models/python/sklearn/sklearn_mixed_djl_sagemaker_v2.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_all_formatters_v4.zip" - 
"https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_input_output_v3.zip" - "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_invalid_input_v3.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_all_formatters.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_input_output.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_djl_invalid_input.zip" "https://resources.djl.ai/test-models/python/sklearn/slow_loading_model.zip" "https://resources.djl.ai/test-models/python/sklearn/slow_predict_model.zip" ) @@ -56,9 +56,9 @@ python_xgb_models_urls=( "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_sagemaker_input_output_invalid.zip" "https://resources.djl.ai/test-models/python/xgboost/xgboost_mixed_djl_sagemaker_v2.zip" - "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_all_formatters.zip" - "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_input_output_v3.zip" - "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_invalid_input_v3.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_all_formatters_v1.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_input_output.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_djl_invalid_input.zip" ) download() { diff --git a/tests/integration/run_all_ml_sm_endpoint_tests.sh b/tests/integration/run_all_ml_sm_endpoint_tests.sh index 9b276d47d..a42d7e8f4 100755 --- a/tests/integration/run_all_ml_sm_endpoint_tests.sh +++ b/tests/integration/run_all_ml_sm_endpoint_tests.sh @@ -26,6 +26,11 @@ MULTI_MODELS=( "sklearn-multi" ) +# Multi-container tests +MULTI_CONTAINER_MODELS=( + "sklearn-xgboost-multi-container" +) + for model in "${MODELS[@]}"; do echo "" | tee -a "$OUTPUT_FILE" echo "Testing: $model with JSON and CSV" | tee -a "$OUTPUT_FILE" @@ -82,6 +87,25 @@ 
for model in "${MULTI_MODELS[@]}"; do echo "----------------------------------------" | tee -a "$OUTPUT_FILE" done +# Test multi-container endpoints +for model in "${MULTI_CONTAINER_MODELS[@]}"; do + echo "" | tee -a "$OUTPUT_FILE" + echo "Testing: $model multi-container endpoint" | tee -a "$OUTPUT_FILE" + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" + + if python3 sagemaker-ml-endpoint-tests.py "$model" "$IMAGE_TYPE" --test-multi-container 2>&1 | tee -a "$OUTPUT_FILE"; then + if grep -q "Multi-container endpoint test completed successfully" "$OUTPUT_FILE"; then + echo "SUCCESS: $model (multi-container)" | tee -a "$OUTPUT_FILE" + else + echo "FAILED: $model (multi-container) - No success message found" | tee -a "$OUTPUT_FILE" + fi + else + echo "FAILED: $model (multi-container) - Script returned error" | tee -a "$OUTPUT_FILE" + fi + + echo "----------------------------------------" | tee -a "$OUTPUT_FILE" +done + echo "" | tee -a "$OUTPUT_FILE" echo "All tests completed at $(date)" | tee -a "$OUTPUT_FILE" echo "Results saved to: $OUTPUT_FILE" \ No newline at end of file diff --git a/tests/integration/sagemaker-ml-endpoint-tests.py b/tests/integration/sagemaker-ml-endpoint-tests.py index 3fa975c5b..6da3bd864 100644 --- a/tests/integration/sagemaker-ml-endpoint-tests.py +++ b/tests/integration/sagemaker-ml-endpoint-tests.py @@ -14,7 +14,6 @@ DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" DEFAULT_BUCKET = "sm-integration-tests-rubikon-usw2" -# DJL Serving CPU images CANDIDATE_IMAGES = { "cpu-full": "125045733377.dkr.ecr.us-west-2.amazonaws.com/djl-serving-cpu-full-test:latest" @@ -43,7 +42,7 @@ }, "sklearn-djl-formatters": { "model_data": - "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_djl_all_formatters_v4.tar", + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_djl_all_formatters.tar", "payload": { "features": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] }, @@ -170,6 +169,47 @@ "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", 
"SAGEMAKER_NUM_MODEL_WORKERS": "2" } + }, +} + +# Multi-container endpoint configuration +MULTI_CONTAINER_CONFIGS = { + "sklearn-xgboost-multi-container": { + "containers": [{ + "name": "sklearn-container", + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/sklearn_skops_model.tar", + "env_vars": { + "OPTION_SKOPS_TRUSTED_TYPES": + "sklearn.ensemble._forest.RandomForestClassifier", + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }, { + "name": "xgboost-container", + "model_data": + "s3://djl-llm-sm-endpoint-tests/skl_xgb/xgboost_model.tar", + "env_vars": { + "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json", + "SAGEMAKER_NUM_MODEL_WORKERS": "2" + } + }], + "test_payloads": { + "sklearn": { + "payload": { + "inputs": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + }, + "xgboost": { + "payload": { + "inputs": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + }, + "csv_payload": "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + } + } } } @@ -205,6 +245,9 @@ def parse_args(): parser.add_argument("--test-multi-model", help="Test multi-model endpoint", action="store_true") + parser.add_argument("--test-multi-container", + help="Test multi-container endpoint", + action="store_true") return parser.parse_args() @@ -306,11 +349,10 @@ def test_endpoint(framework, batch_msg = " with batch" if test_batch else "" print( - f"✓ Successfully tested {framework} model: {model_name}{batch_msg}" - ) + f"Successfully tested {framework} model: {model_name}{batch_msg}") except Exception as e: - print(f"✗ Error testing {framework} model {model_name}: {e}") + print(f"Error testing {framework} model {model_name}: {e}") raise e finally: if predictor: @@ -383,10 +425,10 @@ def test_multi_model_endpoint(model_name, image_type, test_format="json"): predictor.serializer = JSONSerializer() predictor.deserializer = JSONDeserializer() - print(f"✓ 
Successfully tested multi-model endpoint: {model_name}") + print(f"Successfully tested multi-model endpoint: {model_name}") except Exception as e: - print(f"✗ Error testing multi-model endpoint {model_name}: {e}") + print(f"Error testing multi-model endpoint {model_name}: {e}") raise e finally: if predictor: @@ -395,6 +437,119 @@ def test_multi_model_endpoint(model_name, image_type, test_format="json"): model.delete_model() +def test_multi_container_endpoint(config_name, image_uri): + """Test multi-container endpoint with sklearn and xgboost models""" + import time + + if config_name not in MULTI_CONTAINER_CONFIGS: + raise ValueError(f"Unknown multi-container config: {config_name}") + + config = MULTI_CONTAINER_CONFIGS[config_name] + endpoint_name = f"djl-mc-{config_name}-{int(time.time())}" + + print(f"Testing multi-container endpoint: {endpoint_name}") + + sagemaker_client = boto3.client('sagemaker', region_name='us-west-2') + runtime_client = boto3.client('sagemaker-runtime', region_name='us-west-2') + + model_name = None + endpoint_config_name = None + + try: + # Create container definitions + container_defs = [] + for container_config in config["containers"]: + container_def = { + "Image": image_uri, + "ModelDataUrl": container_config["model_data"], + "Environment": container_config["env_vars"], + "ContainerHostname": container_config["name"] + } + container_defs.append(container_def) + + # Create model with Direct inference execution mode for multi-container + model_name = f"djl-mc-model-{int(time.time())}" + sagemaker_client.create_model( + ModelName=model_name, + ExecutionRoleArn=ROLE, + Containers=container_defs, + InferenceExecutionConfig={"Mode": "Direct"}) + + # Create endpoint configuration + endpoint_config_name = f"djl-mc-config-{int(time.time())}" + sagemaker_client.create_endpoint_config( + EndpointConfigName=endpoint_config_name, + ProductionVariants=[{ + 'VariantName': 'AllTraffic', + 'ModelName': model_name, + 'InitialInstanceCount': 1, + 
'InstanceType': DEFAULT_INSTANCE_TYPE, + 'InitialVariantWeight': 1 + }]) + + # Create endpoint + sagemaker_client.create_endpoint( + EndpointName=endpoint_name, + EndpointConfigName=endpoint_config_name) + + # Wait for endpoint to be in service + print(f"Waiting for endpoint {endpoint_name} to be in service...") + waiter = sagemaker_client.get_waiter('endpoint_in_service') + waiter.wait(EndpointName=endpoint_name, + WaiterConfig={ + 'Delay': 30, + 'MaxAttempts': 60 + }) + + # Test predictions on both containers using TargetContainerHostname + for container_name, test_payload in config["test_payloads"].items(): + print(f"Testing {container_name} container...") + + container_hostname = next(c["name"] for c in config["containers"] + if container_name in c["name"]) + + # Test JSON payload + response = runtime_client.invoke_endpoint( + EndpointName=endpoint_name, + ContentType='application/json', + Accept='application/json', + Body=json.dumps(test_payload["payload"]), + TargetContainerHostname=container_hostname) + + result = json.loads(response['Body'].read().decode()) + print(f"{container_name} JSON prediction result: {result}") + + # Test CSV payload + response = runtime_client.invoke_endpoint( + EndpointName=endpoint_name, + ContentType='text/csv', + Accept='application/json', + Body=test_payload["csv_payload"], + TargetContainerHostname=container_hostname) + + result = json.loads(response['Body'].read().decode()) + print(f"{container_name} CSV prediction result: {result}") + + print(f"Multi-container endpoint test completed successfully!") + + except Exception as e: + print(f"Multi-container endpoint test failed: {str(e)}") + raise + finally: + # Cleanup + try: + if endpoint_name: + sagemaker_client.delete_endpoint(EndpointName=endpoint_name) + if endpoint_config_name: + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint_config_name) + if model_name: + sagemaker_client.delete_model(ModelName=model_name) + print(f"Cleaned up multi-container 
endpoint resources") + except Exception as cleanup_error: + print(f"Error during cleanup: {cleanup_error}") + + if __name__ == "__main__": args = parse_args() @@ -406,8 +561,16 @@ def test_multi_model_endpoint(model_name, image_type, test_format="json"): else: test_format = "json" # Default + # Handle multi-container endpoint testing + if args.test_multi_container: + if args.model_name not in MULTI_CONTAINER_CONFIGS: + raise ValueError( + f"Unknown multi-container config: {args.model_name}. Available: {list(MULTI_CONTAINER_CONFIGS.keys())}" + ) + test_multi_container_endpoint(args.model_name, + get_image_uri(args.image_type)) # Handle multi-model endpoint testing - if args.test_multi_model: + elif args.test_multi_model: if args.model_name not in MULTI_MODEL_CONFIGS: raise ValueError( f"Unknown multi-model config: {args.model_name}. Available: {list(MULTI_MODEL_CONFIGS.keys())}" diff --git a/tests/integration/test_custom_formatters.py b/tests/integration/test_custom_formatters.py index 027e7d1bf..8ad39d95c 100644 --- a/tests/integration/test_custom_formatters.py +++ b/tests/integration/test_custom_formatters.py @@ -330,7 +330,7 @@ def test_sklearn_djl_all_formatters(self): with Runner('cpu-full', 'sklearn_djl_all', download=True) as r: r.launch( cmd= - "serve -m sklearn_djl_all::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v4.zip" + "serve -m sklearn_djl_all::Python=file:/opt/ml/model/sklearn_djl_all_formatters.zip" ) # Test DJL decorators @@ -362,7 +362,7 @@ def test_sklearn_djl_env_with_formatters(self): r.launch( env_vars=env, cmd= - "serve -m sklearn_djl_env::Python=file:/opt/ml/model/sklearn_djl_all_formatters_v4.zip" + "serve -m sklearn_djl_env::Python=file:/opt/ml/model/sklearn_djl_all_formatters.zip" ) test_data = { @@ -385,7 +385,7 @@ def test_sklearn_djl_input_output_formatters(self): download=True) as r: r.launch( cmd= - "serve -m sklearn_djl_io::Python=file:/opt/ml/model/sklearn_djl_input_output_v3.zip" + "serve -m 
sklearn_djl_io::Python=file:/opt/ml/model/sklearn_djl_input_output.zip" ) test_data = { @@ -417,7 +417,7 @@ def test_sklearn_djl_invalid_input_formatter(self): download=True) as r: r.launch( cmd= - "serve -m sklearn_djl_invalid::Python=file:/opt/ml/model/sklearn_djl_invalid_input_v3.zip" + "serve -m sklearn_djl_invalid::Python=file:/opt/ml/model/sklearn_djl_invalid_input.zip" ) # Test should fail because input_formatter returns raw list instead of numpy array @@ -439,7 +439,7 @@ def test_xgboost_djl_all_formatters(self): with Runner('cpu-full', 'xgboost_djl_all', download=True) as r: r.launch( cmd= - "serve -m xgboost_djl_all::Python=file:/opt/ml/model/xgboost_djl_all_formatters.zip" + "serve -m xgboost_djl_all::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v1.zip" ) # Test DJL decorators @@ -471,7 +471,7 @@ def test_xgboost_djl_env_with_formatters(self): r.launch( env_vars=env, cmd= - "serve -m xgboost_djl_env::Python=file:/opt/ml/model/xgboost_djl_all_formatters.zip" + "serve -m xgboost_djl_env::Python=file:/opt/ml/model/xgboost_djl_all_formatters_v1.zip" ) test_data = { @@ -494,7 +494,7 @@ def test_xgboost_djl_input_output_formatters(self): download=True) as r: r.launch( cmd= - "serve -m xgboost_djl_io::Python=file:/opt/ml/model/xgboost_djl_input_output_v3.zip" + "serve -m xgboost_djl_io::Python=file:/opt/ml/model/xgboost_djl_input_output.zip" ) test_data = { @@ -526,7 +526,7 @@ def test_xgboost_djl_invalid_input_formatter(self): download=True) as r: r.launch( cmd= - "serve -m xgboost_djl_invalid::Python=file:/opt/ml/model/xgboost_djl_invalid_input_v3.zip" + "serve -m xgboost_djl_invalid::Python=file:/opt/ml/model/xgboost_djl_invalid_input.zip" ) # Test should fail because input_formatter returns raw list instead of numpy array diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index ae675ee29..d61f5e960 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ 
b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -890,9 +890,6 @@ private void loadServingProperties() { } } - // Apply SageMaker compatibility for model-level configurations - SageMakerCompatibility.applyModelCompatibility(prop); - // load default settings from env for (Map.Entry entry : Utils.getenv().entrySet()) { String key = entry.getKey(); @@ -921,6 +918,9 @@ private void loadServingProperties() { arguments.putIfAbsent(key, value); } } + + // Apply SageMaker compatibility for model-level configurations + SageMakerCompatibility.applyModelCompatibility(prop); } }