Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ node_modules/

# dir
tests/integration/models/
tests/integration/sagemaker_test_results.txt
engines/python/setup/djl_python/tests/resources*

tests/integration/awscurl
Expand All @@ -39,3 +40,4 @@ __pycache__
dist/
*.egg-info/
*.pt

104 changes: 93 additions & 11 deletions engines/python/setup/djl_python/custom_formatter_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
import os
from dataclasses import dataclass
from typing import Optional, Callable

from djl_python.service_loader import get_annotated_function
from djl_python.utils import get_sagemaker_function
from djl_python.inputs import Input

logger = logging.getLogger(__name__)

Expand All @@ -29,39 +34,94 @@ def __init__(self, message: str, original_exception: Exception):
class CustomFormatterHandler:

def __init__(self):
self.output_formatter = None
self.input_formatter = None
self.input_formatter: Optional[Callable] = None
self.output_formatter: Optional[Callable] = None
self.prediction_handler: Optional[Callable] = None
self.init_handler: Optional[Callable] = None
self.is_sagemaker_script: bool = False

def load_formatters(self, model_dir: str):
"""Load custom formatters from model.py"""
"""Load custom formatters/handlers from model.py with SageMaker detection"""
try:
self.input_formatter = get_annotated_function(
model_dir, "is_input_formatter")
self.output_formatter = get_annotated_function(
model_dir, "is_output_formatter")
self.prediction_handler = get_annotated_function(
model_dir, "is_prediction_handler")
self.init_handler = get_annotated_function(model_dir,
"is_init_handler")

# Detect SageMaker script pattern for backward compatibility
self._detect_sagemaker_functions(model_dir)

logger.info(
f"Loaded formatters - input: {self.input_formatter}, output: {self.output_formatter}"
)
f"Loaded formatters - input: {bool(self.input_formatter)}, "
f"output: {bool(self.output_formatter)}")
logger.info(
f"Loaded handlers - prediction: {bool(self.prediction_handler)}, "
f"init: {bool(self.init_handler)}, "
f"sagemaker: {self.is_sagemaker_script}")
except Exception as e:
raise CustomFormatterError(
f"Failed to load custom formatters from {model_dir}", e)
f"Failed to load custom code from {model_dir}", e)

def _detect_sagemaker_functions(self, model_dir: str):
"""Detect and load SageMaker-style functions for backward compatibility"""
# If no decorator-based code found, check for SageMaker functions
if not any([
self.input_formatter, self.output_formatter,
self.prediction_handler, self.init_handler
]):
sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn')
sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn')
sagemaker_predict_fn = get_sagemaker_function(
model_dir, 'predict_fn')
sagemaker_output_fn = get_sagemaker_function(
model_dir, 'output_fn')

if any([
sagemaker_model_fn, sagemaker_input_fn,
sagemaker_predict_fn, sagemaker_output_fn
]):
self.is_sagemaker_script = True
if sagemaker_model_fn:
self.init_handler = sagemaker_model_fn
if sagemaker_input_fn:
self.input_formatter = sagemaker_input_fn
if sagemaker_predict_fn:
self.prediction_handler = sagemaker_predict_fn
if sagemaker_output_fn:
self.output_formatter = sagemaker_output_fn
logger.info("Loaded SageMaker-style functions")

def apply_input_formatter(self, decoded_payload, **kwargs):
"""Apply input formatter if available"""
if self.input_formatter:
if self.input_formatter is not None:
try:
return self.input_formatter(decoded_payload, **kwargs)
if self.is_sagemaker_script:
# SageMaker input_fn expects (data, content_type)
content_type = kwargs.get('content_type')
return self.input_formatter(decoded_payload.get_as_bytes(),
content_type)
else:
return self.input_formatter(decoded_payload, **kwargs)
except Exception as e:
logger.exception("Custom input formatter failed")
raise CustomFormatterError(
"Custom input formatter execution failed", e)
return decoded_payload

def apply_output_formatter(self, output):
def apply_output_formatter(self, output, **kwargs):
"""Apply output formatter if available"""
if self.output_formatter:
if self.output_formatter is not None:
try:
return self.output_formatter(output)
if self.is_sagemaker_script:
# SageMaker output_fn expects (predictions, accept)
accept = kwargs.get('accept')
return self.output_formatter(output, accept)
else:
return self.output_formatter(output)
except Exception as e:
logger.exception("Custom output formatter failed")
raise CustomFormatterError(
Expand All @@ -79,3 +139,25 @@ async def apply_output_formatter_streaming_raw(self, stream_generator):
logger.exception("Streaming formatter failed")
raise CustomFormatterError(
"Custom streaming formatter execution failed", e)

def apply_init_handler(self, model_dir, **kwargs):
"""Apply custom init handler if available"""
if self.init_handler is not None:
try:
return self.init_handler(model_dir, **kwargs)
except Exception as e:
logger.exception("Custom init handler failed")
raise CustomFormatterError(
"Custom init handler execution failed", e)
return None

def apply_prediction_handler(self, X, model, **kwargs):
"""Apply custom prediction handler if available"""
if self.prediction_handler is not None:
try:
return self.prediction_handler(X, model, **kwargs)
except Exception as e:
logger.exception("Custom prediction handler failed")
raise CustomFormatterError(
"Custom prediction handler execution failed", e)
return None
22 changes: 22 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@ def input_formatter(function):
return function


def prediction_handler(function):
"""
Decorator for prediction_handler. User just need to annotate @prediction_handler for their custom defined function.
:param function: Decorator takes in the function and adds an attribute.
:return:
"""
# adding an attribute to the function, which is used to find the decorated function.
function.is_prediction_handler = True
return function


def init_handler(function):
"""
Decorator for init_handler. User just need to annotate @init_handler for their custom defined function.
:param function: Decorator takes in the function and adds an attribute.
:return:
"""
# adding an attribute to the function, which is used to find the decorated function.
function.is_init_handler = True
return function


@dataclass
class ParsedInput:
errors: dict = field(default_factory=lambda: {})
Expand Down
122 changes: 67 additions & 55 deletions engines/python/setup/djl_python/sklearn_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
from djl_python import Input, Output
from djl_python.encode_decode import decode
from djl_python.utils import find_model_file
from djl_python.service_loader import get_annotated_function
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio


class SklearnHandler:
class SklearnHandler(CustomFormatterHandler):

def __init__(self):
super().__init__()
self.model = None
self.initialized = False
self.custom_input_formatter = None
self.custom_output_formatter = None
self.custom_predict_formatter = None
self.init_properties = None

def _get_trusted_types(self, properties: dict):
trusted_types_str = properties.get("skops_trusted_types", "")
Expand All @@ -46,6 +45,8 @@ def _get_trusted_types(self, properties: dict):
return trusted_types

def initialize(self, properties: dict):
# Store initialization properties for use during inference
self.init_properties = properties.copy()
model_dir = properties.get("model_dir")
model_format = properties.get("model_format", "skops")

Expand All @@ -62,64 +63,74 @@ def initialize(self, properties: dict):
f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
)

model_file = find_model_file(model_dir, extensions)
if not model_file:
raise FileNotFoundError(
f"No model file found with format '{model_format}' in {model_dir}"
)
# Load custom code
try:
self.load_formatters(model_dir)
except CustomFormatterError as e:
raise

if model_format == "skops":
trusted_types = self._get_trusted_types(properties)
self.model = sio.load(model_file, trusted=trusted_types)
# Load model
init_result = self.apply_init_handler(model_dir)
if init_result is not None:
self.model = init_result
else:
if properties.get("trust_insecure_model_files",
"false").lower() != "true":
raise ValueError(
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
model_file = find_model_file(model_dir, extensions)
if not model_file:
raise FileNotFoundError(
f"No model file found with format '{model_format}' in {model_dir}"
)

if model_format == "joblib":
self.model = joblib.load(model_file)
elif model_format == "pickle":
with open(model_file, 'rb') as f:
self.model = pickle.load(f)
elif model_format == "cloudpickle":
with open(model_file, 'rb') as f:
self.model = cloudpickle.load(f)

self.custom_input_formatter = get_annotated_function(
model_dir, "is_input_formatter")
self.custom_output_formatter = get_annotated_function(
model_dir, "is_output_formatter")
self.custom_predict_formatter = get_annotated_function(
model_dir, "is_predict_formatter")
if model_format == "skops":
trusted_types = self._get_trusted_types(properties)
self.model = sio.load(model_file, trusted=trusted_types)
else:
if properties.get("trust_insecure_model_files",
"false").lower() != "true":
raise ValueError(
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
)

if model_format == "joblib":
self.model = joblib.load(model_file)
elif model_format == "pickle":
with open(model_file, 'rb') as f:
self.model = pickle.load(f)
elif model_format == "cloudpickle":
with open(model_file, 'rb') as f:
self.model = cloudpickle.load(f)

self.initialized = True

def inference(self, inputs: Input) -> Output:
content_type = inputs.get_property("Content-Type")
accept = inputs.get_property("Accept") or "application/json"
properties = inputs.get_properties()
default_accept = self.init_properties.get("default_accept",
"application/json")

accept = inputs.get_property("Accept")

# If no accept type is specified in the request, use default
if accept == "*/*":
accept = default_accept

# Validate accept type (skip validation if custom output formatter is provided)
if not self.custom_output_formatter:
supported_accept_types = ["application/json", "text/csv"]
if self.output_formatter is None: # No formatter available
if not any(supported_type in accept
for supported_type in supported_accept_types):
for supported_type in ["application/json", "text/csv"]):
raise ValueError(
f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}"
f"Unsupported Accept type: {accept}. Supported types: application/json, text/csv"
)

# Input processing
X = None
if self.custom_input_formatter:
X = self.custom_input_formatter(inputs)
elif "text/csv" in content_type:
X = decode(inputs, content_type, require_csv_headers=False)
else:
input_map = decode(inputs, content_type)
data = input_map.get("inputs") if isinstance(input_map,
dict) else input_map
X = np.array(data)
X = self.apply_input_formatter(inputs, content_type=content_type)
if X is inputs: # No formatter applied
if "text/csv" in content_type:
X = decode(inputs, content_type, require_csv_headers=False)
else:
input_map = decode(inputs, content_type)
data = input_map.get("inputs") if isinstance(
input_map, dict) else input_map
X = np.array(data)

if X is None or not hasattr(X, 'ndim'):
raise ValueError(
Expand All @@ -128,18 +139,19 @@ def inference(self, inputs: Input) -> Output:
if X.ndim == 1:
X = X.reshape(1, -1)

if self.custom_predict_formatter:
predictions = self.custom_predict_formatter(self.model, X)
else:
predictions = self.apply_prediction_handler(X, self.model)
if predictions is None:
predictions = self.model.predict(X)

# Output processing
if self.custom_output_formatter:
return self.custom_output_formatter(predictions)

# Supports CSV/JSON outputs by default
outputs = Output()
if "text/csv" in accept:
formatted_output = self.apply_output_formatter(predictions,
accept=accept)
if formatted_output is not predictions: # Formatter was applied
if self.is_sagemaker_script:
outputs.add_property("Content-Type", accept)
outputs.add(formatted_output)
elif "text/csv" in accept:
csv_buffer = StringIO()
np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
outputs.add(csv_buffer.getvalue().rstrip())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def test_decode_text_csv(self):
mock_decode_csv.return_value = {"inputs": ["test input"]}
result = decode(self.mock_input, "text/csv")

mock_decode_csv.assert_called_once_with(self.mock_input)
mock_decode_csv.assert_called_once_with(self.mock_input,
require_headers=True)
self.assertEqual(result, {"inputs": ["test input"]})

def test_decode_text_plain(self):
Expand Down
Loading
Loading