Skip to content

Commit 63554fc

Browse files
authored
Added support for Sagemaker ENV variables + custom formatter scripts (#2921)
1 parent 28ac358 commit 63554fc

File tree

19 files changed

+2547
-177
lines changed

19 files changed

+2547
-177
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ node_modules/
3131

3232
# dir
3333
tests/integration/models/
34+
tests/integration/sagemaker_test_results.txt
3435
engines/python/setup/djl_python/tests/resources*
3536

3637
tests/integration/awscurl
@@ -39,3 +40,4 @@ __pycache__
3940
dist/
4041
*.egg-info/
4142
*.pt
43+

engines/python/setup/djl_python/custom_formatter_handling.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@
1111
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
1212
# the specific language governing permissions and limitations under the License.
1313
import logging
14+
import os
15+
from dataclasses import dataclass
16+
from typing import Optional, Callable
1417

1518
from djl_python.service_loader import get_annotated_function
19+
from djl_python.utils import get_sagemaker_function
20+
from djl_python.inputs import Input
1621

1722
logger = logging.getLogger(__name__)
1823

@@ -29,39 +34,94 @@ def __init__(self, message: str, original_exception: Exception):
2934
class CustomFormatterHandler:
3035

3136
def __init__(self):
32-
self.output_formatter = None
33-
self.input_formatter = None
37+
self.input_formatter: Optional[Callable] = None
38+
self.output_formatter: Optional[Callable] = None
39+
self.prediction_handler: Optional[Callable] = None
40+
self.init_handler: Optional[Callable] = None
41+
self.is_sagemaker_script: bool = False
3442

3543
def load_formatters(self, model_dir: str):
36-
"""Load custom formatters from model.py"""
44+
"""Load custom formatters/handlers from model.py with SageMaker detection"""
3745
try:
3846
self.input_formatter = get_annotated_function(
3947
model_dir, "is_input_formatter")
4048
self.output_formatter = get_annotated_function(
4149
model_dir, "is_output_formatter")
50+
self.prediction_handler = get_annotated_function(
51+
model_dir, "is_prediction_handler")
52+
self.init_handler = get_annotated_function(model_dir,
53+
"is_init_handler")
54+
55+
# Detect SageMaker script pattern for backward compatibility
56+
self._detect_sagemaker_functions(model_dir)
57+
4258
logger.info(
43-
f"Loaded formatters - input: {self.input_formatter}, output: {self.output_formatter}"
44-
)
59+
f"Loaded formatters - input: {bool(self.input_formatter)}, "
60+
f"output: {bool(self.output_formatter)}")
61+
logger.info(
62+
f"Loaded handlers - prediction: {bool(self.prediction_handler)}, "
63+
f"init: {bool(self.init_handler)}, "
64+
f"sagemaker: {self.is_sagemaker_script}")
4565
except Exception as e:
4666
raise CustomFormatterError(
47-
f"Failed to load custom formatters from {model_dir}", e)
67+
f"Failed to load custom code from {model_dir}", e)
68+
69+
def _detect_sagemaker_functions(self, model_dir: str):
70+
"""Detect and load SageMaker-style functions for backward compatibility"""
71+
# If no decorator-based code found, check for SageMaker functions
72+
if not any([
73+
self.input_formatter, self.output_formatter,
74+
self.prediction_handler, self.init_handler
75+
]):
76+
sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn')
77+
sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn')
78+
sagemaker_predict_fn = get_sagemaker_function(
79+
model_dir, 'predict_fn')
80+
sagemaker_output_fn = get_sagemaker_function(
81+
model_dir, 'output_fn')
82+
83+
if any([
84+
sagemaker_model_fn, sagemaker_input_fn,
85+
sagemaker_predict_fn, sagemaker_output_fn
86+
]):
87+
self.is_sagemaker_script = True
88+
if sagemaker_model_fn:
89+
self.init_handler = sagemaker_model_fn
90+
if sagemaker_input_fn:
91+
self.input_formatter = sagemaker_input_fn
92+
if sagemaker_predict_fn:
93+
self.prediction_handler = sagemaker_predict_fn
94+
if sagemaker_output_fn:
95+
self.output_formatter = sagemaker_output_fn
96+
logger.info("Loaded SageMaker-style functions")
4897

4998
def apply_input_formatter(self, decoded_payload, **kwargs):
5099
"""Apply input formatter if available"""
51-
if self.input_formatter:
100+
if self.input_formatter is not None:
52101
try:
53-
return self.input_formatter(decoded_payload, **kwargs)
102+
if self.is_sagemaker_script:
103+
# SageMaker input_fn expects (data, content_type)
104+
content_type = kwargs.get('content_type')
105+
return self.input_formatter(decoded_payload.get_as_bytes(),
106+
content_type)
107+
else:
108+
return self.input_formatter(decoded_payload, **kwargs)
54109
except Exception as e:
55110
logger.exception("Custom input formatter failed")
56111
raise CustomFormatterError(
57112
"Custom input formatter execution failed", e)
58113
return decoded_payload
59114

60-
def apply_output_formatter(self, output):
115+
def apply_output_formatter(self, output, **kwargs):
61116
"""Apply output formatter if available"""
62-
if self.output_formatter:
117+
if self.output_formatter is not None:
63118
try:
64-
return self.output_formatter(output)
119+
if self.is_sagemaker_script:
120+
# SageMaker output_fn expects (predictions, accept)
121+
accept = kwargs.get('accept')
122+
return self.output_formatter(output, accept)
123+
else:
124+
return self.output_formatter(output)
65125
except Exception as e:
66126
logger.exception("Custom output formatter failed")
67127
raise CustomFormatterError(
@@ -79,3 +139,25 @@ async def apply_output_formatter_streaming_raw(self, stream_generator):
79139
logger.exception("Streaming formatter failed")
80140
raise CustomFormatterError(
81141
"Custom streaming formatter execution failed", e)
142+
143+
def apply_init_handler(self, model_dir, **kwargs):
144+
"""Apply custom init handler if available"""
145+
if self.init_handler is not None:
146+
try:
147+
return self.init_handler(model_dir, **kwargs)
148+
except Exception as e:
149+
logger.exception("Custom init handler failed")
150+
raise CustomFormatterError(
151+
"Custom init handler execution failed", e)
152+
return None
153+
154+
def apply_prediction_handler(self, X, model, **kwargs):
155+
"""Apply custom prediction handler if available"""
156+
if self.prediction_handler is not None:
157+
try:
158+
return self.prediction_handler(X, model, **kwargs)
159+
except Exception as e:
160+
logger.exception("Custom prediction handler failed")
161+
raise CustomFormatterError(
162+
"Custom prediction handler execution failed", e)
163+
return None

engines/python/setup/djl_python/input_parser.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ def input_formatter(function):
3636
return function
3737

3838

39+
def prediction_handler(function):
40+
"""
41+
Decorator for prediction_handler. User just need to annotate @prediction_handler for their custom defined function.
42+
:param function: Decorator takes in the function and adds an attribute.
43+
:return:
44+
"""
45+
# adding an attribute to the function, which is used to find the decorated function.
46+
function.is_prediction_handler = True
47+
return function
48+
49+
50+
def init_handler(function):
51+
"""
52+
Decorator for init_handler. User just need to annotate @init_handler for their custom defined function.
53+
:param function: Decorator takes in the function and adds an attribute.
54+
:return:
55+
"""
56+
# adding an attribute to the function, which is used to find the decorated function.
57+
function.is_init_handler = True
58+
return function
59+
60+
3961
@dataclass
4062
class ParsedInput:
4163
errors: dict = field(default_factory=lambda: {})

engines/python/setup/djl_python/sklearn_handler.py

Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,17 @@
1919
from djl_python import Input, Output
2020
from djl_python.encode_decode import decode
2121
from djl_python.utils import find_model_file
22-
from djl_python.service_loader import get_annotated_function
22+
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
2323
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio
2424

2525

26-
class SklearnHandler:
26+
class SklearnHandler(CustomFormatterHandler):
2727

2828
def __init__(self):
29+
super().__init__()
2930
self.model = None
3031
self.initialized = False
31-
self.custom_input_formatter = None
32-
self.custom_output_formatter = None
33-
self.custom_predict_formatter = None
32+
self.init_properties = None
3433

3534
def _get_trusted_types(self, properties: dict):
3635
trusted_types_str = properties.get("skops_trusted_types", "")
@@ -46,6 +45,8 @@ def _get_trusted_types(self, properties: dict):
4645
return trusted_types
4746

4847
def initialize(self, properties: dict):
48+
# Store initialization properties for use during inference
49+
self.init_properties = properties.copy()
4950
model_dir = properties.get("model_dir")
5051
model_format = properties.get("model_format", "skops")
5152

@@ -62,64 +63,74 @@ def initialize(self, properties: dict):
6263
f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
6364
)
6465

65-
model_file = find_model_file(model_dir, extensions)
66-
if not model_file:
67-
raise FileNotFoundError(
68-
f"No model file found with format '{model_format}' in {model_dir}"
69-
)
66+
# Load custom code
67+
try:
68+
self.load_formatters(model_dir)
69+
except CustomFormatterError as e:
70+
raise
7071

71-
if model_format == "skops":
72-
trusted_types = self._get_trusted_types(properties)
73-
self.model = sio.load(model_file, trusted=trusted_types)
72+
# Load model
73+
init_result = self.apply_init_handler(model_dir)
74+
if init_result is not None:
75+
self.model = init_result
7476
else:
75-
if properties.get("trust_insecure_model_files",
76-
"false").lower() != "true":
77-
raise ValueError(
78-
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
77+
model_file = find_model_file(model_dir, extensions)
78+
if not model_file:
79+
raise FileNotFoundError(
80+
f"No model file found with format '{model_format}' in {model_dir}"
7981
)
8082

81-
if model_format == "joblib":
82-
self.model = joblib.load(model_file)
83-
elif model_format == "pickle":
84-
with open(model_file, 'rb') as f:
85-
self.model = pickle.load(f)
86-
elif model_format == "cloudpickle":
87-
with open(model_file, 'rb') as f:
88-
self.model = cloudpickle.load(f)
89-
90-
self.custom_input_formatter = get_annotated_function(
91-
model_dir, "is_input_formatter")
92-
self.custom_output_formatter = get_annotated_function(
93-
model_dir, "is_output_formatter")
94-
self.custom_predict_formatter = get_annotated_function(
95-
model_dir, "is_predict_formatter")
83+
if model_format == "skops":
84+
trusted_types = self._get_trusted_types(properties)
85+
self.model = sio.load(model_file, trusted=trusted_types)
86+
else:
87+
if properties.get("trust_insecure_model_files",
88+
"false").lower() != "true":
89+
raise ValueError(
90+
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
91+
)
92+
93+
if model_format == "joblib":
94+
self.model = joblib.load(model_file)
95+
elif model_format == "pickle":
96+
with open(model_file, 'rb') as f:
97+
self.model = pickle.load(f)
98+
elif model_format == "cloudpickle":
99+
with open(model_file, 'rb') as f:
100+
self.model = cloudpickle.load(f)
96101

97102
self.initialized = True
98103

99104
def inference(self, inputs: Input) -> Output:
100105
content_type = inputs.get_property("Content-Type")
101-
accept = inputs.get_property("Accept") or "application/json"
106+
properties = inputs.get_properties()
107+
default_accept = self.init_properties.get("default_accept",
108+
"application/json")
109+
110+
accept = inputs.get_property("Accept")
111+
112+
# If no accept type is specified in the request, use default
113+
if accept == "*/*":
114+
accept = default_accept
102115

103116
# Validate accept type (skip validation if custom output formatter is provided)
104-
if not self.custom_output_formatter:
105-
supported_accept_types = ["application/json", "text/csv"]
117+
if self.output_formatter is None: # No formatter available
106118
if not any(supported_type in accept
107-
for supported_type in supported_accept_types):
119+
for supported_type in ["application/json", "text/csv"]):
108120
raise ValueError(
109-
f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}"
121+
f"Unsupported Accept type: {accept}. Supported types: application/json, text/csv"
110122
)
111123

112124
# Input processing
113-
X = None
114-
if self.custom_input_formatter:
115-
X = self.custom_input_formatter(inputs)
116-
elif "text/csv" in content_type:
117-
X = decode(inputs, content_type, require_csv_headers=False)
118-
else:
119-
input_map = decode(inputs, content_type)
120-
data = input_map.get("inputs") if isinstance(input_map,
121-
dict) else input_map
122-
X = np.array(data)
125+
X = self.apply_input_formatter(inputs, content_type=content_type)
126+
if X is inputs: # No formatter applied
127+
if "text/csv" in content_type:
128+
X = decode(inputs, content_type, require_csv_headers=False)
129+
else:
130+
input_map = decode(inputs, content_type)
131+
data = input_map.get("inputs") if isinstance(
132+
input_map, dict) else input_map
133+
X = np.array(data)
123134

124135
if X is None or not hasattr(X, 'ndim'):
125136
raise ValueError(
@@ -128,18 +139,19 @@ def inference(self, inputs: Input) -> Output:
128139
if X.ndim == 1:
129140
X = X.reshape(1, -1)
130141

131-
if self.custom_predict_formatter:
132-
predictions = self.custom_predict_formatter(self.model, X)
133-
else:
142+
predictions = self.apply_prediction_handler(X, self.model)
143+
if predictions is None:
134144
predictions = self.model.predict(X)
135145

136146
# Output processing
137-
if self.custom_output_formatter:
138-
return self.custom_output_formatter(predictions)
139-
140-
# Supports CSV/JSON outputs by default
141147
outputs = Output()
142-
if "text/csv" in accept:
148+
formatted_output = self.apply_output_formatter(predictions,
149+
accept=accept)
150+
if formatted_output is not predictions: # Formatter was applied
151+
if self.is_sagemaker_script:
152+
outputs.add_property("Content-Type", accept)
153+
outputs.add(formatted_output)
154+
elif "text/csv" in accept:
143155
csv_buffer = StringIO()
144156
np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
145157
outputs.add(csv_buffer.getvalue().rstrip())

engines/python/setup/djl_python/tests/test_encode_decode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_decode_text_csv(self):
115115
mock_decode_csv.return_value = {"inputs": ["test input"]}
116116
result = decode(self.mock_input, "text/csv")
117117

118-
mock_decode_csv.assert_called_once_with(self.mock_input)
118+
mock_decode_csv.assert_called_once_with(self.mock_input,
119+
require_headers=True)
119120
self.assertEqual(result, {"inputs": ["test input"]})
120121

121122
def test_decode_text_plain(self):

0 commit comments

Comments
 (0)