Skip to content

Commit 7dde753

Browse files
committed
Added support for Sagemaker ENV variables + custom formatter scripts
1 parent 3e49622 commit 7dde753

File tree

13 files changed

+1274
-131
lines changed

13 files changed

+1274
-131
lines changed

engines/python/setup/djl_python/input_parser.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ def input_formatter(function):
3636
return function
3737

3838

39+
def predict_formatter(function):
40+
"""
41+
Decorator for predict_formatter. User just need to annotate @predict_formatter for their custom defined function.
42+
:param function: Decorator takes in the function and adds an attribute.
43+
:return:
44+
"""
45+
# adding an attribute to the function, which is used to find the decorated function.
46+
function.is_predict_formatter = True
47+
return function
48+
49+
50+
def model_loading_formatter(function):
51+
"""
52+
Decorator for model_loading_formatter. User just need to annotate @model_loading_formatter for their custom defined function.
53+
:param function: Decorator takes in the function and adds an attribute.
54+
:return:
55+
"""
56+
# adding an attribute to the function, which is used to find the decorated function.
57+
function.is_model_loading_formatter = True
58+
return function
59+
60+
3961
@dataclass
4062
class ParsedInput:
4163
errors: dict = field(default_factory=lambda: {})

engines/python/setup/djl_python/sklearn_handler.py

Lines changed: 96 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Optional
1919
from djl_python import Input, Output
2020
from djl_python.encode_decode import decode
21-
from djl_python.utils import find_model_file
21+
from djl_python.utils import find_model_file, get_sagemaker_function
2222
from djl_python.service_loader import get_annotated_function
2323
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio
2424

@@ -31,6 +31,49 @@ def __init__(self):
3131
self.custom_input_formatter = None
3232
self.custom_output_formatter = None
3333
self.custom_predict_formatter = None
34+
self.custom_model_loading_formatter = None
35+
self.init_properties = None
36+
self.is_sagemaker_script = False
37+
38+
def _load_custom_formatters(self, model_dir: str):
39+
"""Load custom formatters, checking DJL decorators first, then SageMaker functions."""
40+
# Check for DJL decorator-based custom formatters first
41+
self.custom_model_loading_formatter = get_annotated_function(
42+
model_dir, "is_model_loading_formatter")
43+
self.custom_input_formatter = get_annotated_function(
44+
model_dir, "is_input_formatter")
45+
self.custom_output_formatter = get_annotated_function(
46+
model_dir, "is_output_formatter")
47+
self.custom_predict_formatter = get_annotated_function(
48+
model_dir, "is_predict_formatter")
49+
50+
# If no decorator-based formatters found, check for SageMaker-style formatters
51+
if not any([
52+
self.custom_input_formatter, self.custom_output_formatter,
53+
self.custom_predict_formatter,
54+
self.custom_model_loading_formatter
55+
]):
56+
57+
sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn')
58+
sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn')
59+
sagemaker_predict_fn = get_sagemaker_function(
60+
model_dir, 'predict_fn')
61+
sagemaker_output_fn = get_sagemaker_function(
62+
model_dir, 'output_fn')
63+
64+
if any([
65+
sagemaker_model_fn, sagemaker_input_fn,
66+
sagemaker_predict_fn, sagemaker_output_fn
67+
]):
68+
self.is_sagemaker_script = True
69+
if sagemaker_model_fn:
70+
self.custom_model_loading_formatter = sagemaker_model_fn
71+
if sagemaker_input_fn:
72+
self.custom_input_formatter = sagemaker_input_fn
73+
if sagemaker_predict_fn:
74+
self.custom_predict_formatter = sagemaker_predict_fn
75+
if sagemaker_output_fn:
76+
self.custom_output_formatter = sagemaker_output_fn
3477

3578
def _get_trusted_types(self, properties: dict):
3679
trusted_types_str = properties.get("skops_trusted_types", "")
@@ -46,6 +89,8 @@ def _get_trusted_types(self, properties: dict):
4689
return trusted_types
4790

4891
def initialize(self, properties: dict):
92+
# Store initialization properties for use during inference
93+
self.init_properties = properties.copy()
4994
model_dir = properties.get("model_dir")
5095
model_format = properties.get("model_format", "skops")
5196

@@ -62,43 +107,51 @@ def initialize(self, properties: dict):
62107
f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
63108
)
64109

65-
model_file = find_model_file(model_dir, extensions)
66-
if not model_file:
67-
raise FileNotFoundError(
68-
f"No model file found with format '{model_format}' in {model_dir}"
69-
)
110+
# Load custom formatters
111+
self._load_custom_formatters(model_dir)
70112

71-
if model_format == "skops":
72-
trusted_types = self._get_trusted_types(properties)
73-
self.model = sio.load(model_file, trusted=trusted_types)
113+
# Load model
114+
if self.custom_model_loading_formatter:
115+
self.model = self.custom_model_loading_formatter(model_dir)
74116
else:
75-
if properties.get("trust_insecure_model_files",
76-
"false").lower() != "true":
77-
raise ValueError(
78-
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
117+
model_file = find_model_file(model_dir, extensions)
118+
if not model_file:
119+
raise FileNotFoundError(
120+
f"No model file found with format '{model_format}' in {model_dir}"
79121
)
80122

81-
if model_format == "joblib":
82-
self.model = joblib.load(model_file)
83-
elif model_format == "pickle":
84-
with open(model_file, 'rb') as f:
85-
self.model = pickle.load(f)
86-
elif model_format == "cloudpickle":
87-
with open(model_file, 'rb') as f:
88-
self.model = cloudpickle.load(f)
89-
90-
self.custom_input_formatter = get_annotated_function(
91-
model_dir, "is_input_formatter")
92-
self.custom_output_formatter = get_annotated_function(
93-
model_dir, "is_output_formatter")
94-
self.custom_predict_formatter = get_annotated_function(
95-
model_dir, "is_predict_formatter")
123+
if model_format == "skops":
124+
trusted_types = self._get_trusted_types(properties)
125+
self.model = sio.load(model_file, trusted=trusted_types)
126+
else:
127+
if properties.get("trust_insecure_model_files",
128+
"false").lower() != "true":
129+
raise ValueError(
130+
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
131+
)
132+
133+
if model_format == "joblib":
134+
self.model = joblib.load(model_file)
135+
elif model_format == "pickle":
136+
with open(model_file, 'rb') as f:
137+
self.model = pickle.load(f)
138+
elif model_format == "cloudpickle":
139+
with open(model_file, 'rb') as f:
140+
self.model = cloudpickle.load(f)
96141

97142
self.initialized = True
98143

99144
def inference(self, inputs: Input) -> Output:
100145
content_type = inputs.get_property("Content-Type")
101-
accept = inputs.get_property("Accept") or "application/json"
146+
properties = inputs.get_properties()
147+
default_accept = self.init_properties.get("default_accept",
148+
"application/json")
149+
150+
accept = inputs.get_property("Accept")
151+
152+
# If no accept type is specified in the request, use default
153+
if accept == "*/*":
154+
accept = default_accept
102155

103156
# Validate accept type (skip validation if custom output formatter is provided)
104157
if not self.custom_output_formatter:
@@ -112,7 +165,11 @@ def inference(self, inputs: Input) -> Output:
112165
# Input processing
113166
X = None
114167
if self.custom_input_formatter:
115-
X = self.custom_input_formatter(inputs)
168+
if self.is_sagemaker_script:
169+
X = self.custom_input_formatter(inputs.get_as_bytes(),
170+
content_type)
171+
else:
172+
X = self.custom_input_formatter(inputs)
116173
elif "text/csv" in content_type:
117174
X = decode(inputs, content_type, require_csv_headers=False)
118175
else:
@@ -129,17 +186,20 @@ def inference(self, inputs: Input) -> Output:
129186
X = X.reshape(1, -1)
130187

131188
if self.custom_predict_formatter:
132-
predictions = self.custom_predict_formatter(self.model, X)
189+
predictions = self.custom_predict_formatter(X, self.model)
133190
else:
134191
predictions = self.model.predict(X)
135192

136193
# Output processing
137-
if self.custom_output_formatter:
138-
return self.custom_output_formatter(predictions)
139-
140-
# Supports CSV/JSON outputs by default
141194
outputs = Output()
142-
if "text/csv" in accept:
195+
if self.custom_output_formatter:
196+
if self.is_sagemaker_script:
197+
data = self.custom_output_formatter(predictions, accept)
198+
outputs.add_property("Content-Type", accept)
199+
else:
200+
data = self.custom_output_formatter(predictions)
201+
outputs.add(data)
202+
elif "text/csv" in accept:
143203
csv_buffer = StringIO()
144204
np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
145205
outputs.add(csv_buffer.getvalue().rstrip())

engines/python/setup/djl_python/tests/test_encode_decode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_decode_text_csv(self):
115115
mock_decode_csv.return_value = {"inputs": ["test input"]}
116116
result = decode(self.mock_input, "text/csv")
117117

118-
mock_decode_csv.assert_called_once_with(self.mock_input)
118+
mock_decode_csv.assert_called_once_with(self.mock_input,
119+
require_headers=True)
119120
self.assertEqual(result, {"inputs": ["test input"]})
120121

121122
def test_decode_text_plain(self):

engines/python/setup/djl_python/utils.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,21 @@
1313
import glob
1414
import logging
1515
import os
16-
from typing import Optional, List
16+
import inspect
17+
import importlib.util
18+
from typing import Optional, List, Callable
1719

1820
from djl_python import Output
1921
from djl_python.inputs import Input
22+
from djl_python.service_loader import load_model_service, has_function_in_module
23+
24+
# SageMaker function signatures for validation
25+
SAGEMAKER_SIGNATURES = {
26+
'model_fn': ['model_dir'],
27+
'input_fn': ['request_body', 'content_type'],
28+
'predict_fn': ['input_data', 'model'],
29+
'output_fn': ['prediction', 'accept']
30+
}
2031

2132

2233
class IdCounter:
@@ -188,3 +199,59 @@ def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]:
188199
)
189200

190201
return all_matches[0] if all_matches else None
202+
203+
204+
def _validate_sagemaker_function(
205+
module, func_name: str,
206+
expected_params: List[str]) -> Optional[Callable]:
207+
"""
208+
Validate that function exists and has correct signature
209+
Returns the function if valid, None otherwise
210+
"""
211+
if not hasattr(module, func_name):
212+
return None
213+
214+
func = getattr(module, func_name)
215+
if not callable(func):
216+
return None
217+
218+
try:
219+
sig = inspect.signature(func)
220+
param_names = list(sig.parameters.keys())
221+
222+
# Check parameter count and names match exactly
223+
if param_names == expected_params:
224+
return func
225+
except (ValueError, TypeError):
226+
# Handle cases where signature inspection fails
227+
pass
228+
229+
return None
230+
231+
232+
def get_sagemaker_function(model_dir: str,
233+
func_name: str) -> Optional[Callable]:
234+
"""
235+
Load and validate SageMaker-style formatter function from model.py
236+
237+
:param model_dir: model directory containing model.py
238+
:param func_name: SageMaker function name (model_fn, input_fn, predict_fn, output_fn)
239+
:return: Validated function or None if not found/invalid
240+
"""
241+
242+
if func_name not in SAGEMAKER_SIGNATURES:
243+
return None
244+
245+
try:
246+
service = load_model_service(model_dir, "model.py", -1)
247+
if has_function_in_module(service.module, func_name):
248+
func = getattr(service.module, func_name)
249+
# Optional: validate signature
250+
expected_params = SAGEMAKER_SIGNATURES[func_name]
251+
if _validate_sagemaker_function(service.module, func_name,
252+
expected_params):
253+
return func
254+
255+
except Exception as e:
256+
logging.debug(f"Failed to load {func_name} from model.py: {e}")
257+
return None

0 commit comments

Comments
 (0)