Skip to content

Commit c0a0d4a

Browse files
committed
Encapsulated custom handler/formatter logic and added SageMaker endpoint integration tests
1 parent 7dde753 commit c0a0d4a

File tree

10 files changed

+667
-162
lines changed

10 files changed

+667
-162
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ node_modules/
3131

3232
# dir
3333
tests/integration/models/
34+
tests/integration/sagemaker_test_results.txt
3435
engines/python/setup/djl_python/tests/resources*
3536

3637
tests/integration/awscurl
@@ -39,3 +40,4 @@ __pycache__
3940
dist/
4041
*.egg-info/
4142
*.pt
43+

engines/python/setup/djl_python/custom_formatter_handling.py

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
1212
# the specific language governing permissions and limitations under the License.
1313
import logging
14+
import os
15+
from dataclasses import dataclass
16+
from typing import Optional, Callable
1417

1518
from djl_python.service_loader import get_annotated_function
19+
from djl_python.utils import get_sagemaker_function
1620

1721
logger = logging.getLogger(__name__)
1822

@@ -26,31 +30,103 @@ def __init__(self, message: str, original_exception: Exception):
2630
self.__cause__ = original_exception
2731

2832

33+
@dataclass
34+
class CustomFormatters:
35+
"""Container for input/output formatting functions"""
36+
input_formatter: Optional[Callable] = None
37+
output_formatter: Optional[Callable] = None
38+
39+
40+
@dataclass
41+
class CustomHandlers:
42+
"""Container for prediction/initialization handler functions"""
43+
prediction_handler: Optional[Callable] = None
44+
init_handler: Optional[Callable] = None
45+
46+
47+
@dataclass
48+
class CustomCode:
49+
"""Container for all custom formatters and handlers"""
50+
formatters: CustomFormatters
51+
handlers: CustomHandlers
52+
is_sagemaker_script: bool = False
53+
54+
def __init__(self):
55+
self.formatters = CustomFormatters()
56+
self.handlers = CustomHandlers()
57+
self.is_sagemaker_script = False
58+
59+
2960
class CustomFormatterHandler:
3061

3162
def __init__(self):
32-
self.output_formatter = None
33-
self.input_formatter = None
63+
self.custom_code = CustomCode()
3464

35-
def load_formatters(self, model_dir: str):
36-
"""Load custom formatters from model.py"""
65+
def load_formatters(self, model_dir: str) -> CustomCode:
66+
"""Load custom formatters/handlers from model.py with SageMaker detection"""
3767
try:
38-
self.input_formatter = get_annotated_function(
68+
self.custom_code.formatters.input_formatter = get_annotated_function(
3969
model_dir, "is_input_formatter")
40-
self.output_formatter = get_annotated_function(
70+
self.custom_code.formatters.output_formatter = get_annotated_function(
4171
model_dir, "is_output_formatter")
72+
self.custom_code.handlers.prediction_handler = get_annotated_function(
73+
model_dir, "is_prediction_handler")
74+
self.custom_code.handlers.init_handler = get_annotated_function(
75+
model_dir, "is_init_handler")
76+
77+
# Detect SageMaker script pattern for backward compatibility
78+
self._detect_sagemaker_functions(model_dir)
79+
4280
logger.info(
43-
f"Loaded formatters - input: {self.input_formatter}, output: {self.output_formatter}"
81+
f"Loaded formatters - input: {bool(self.custom_code.formatters.input_formatter)}, "
82+
f"output: {bool(self.custom_code.formatters.output_formatter)}"
4483
)
84+
logger.info(
85+
f"Loaded handlers - prediction: {bool(self.custom_code.handlers.prediction_handler)}, "
86+
f"init: {bool(self.custom_code.handlers.init_handler)}, "
87+
f"sagemaker: {self.custom_code.is_sagemaker_script}")
88+
return self.custom_code
4589
except Exception as e:
4690
raise CustomFormatterError(
47-
f"Failed to load custom formatters from {model_dir}", e)
91+
f"Failed to load custom code from {model_dir}", e)
92+
93+
def _detect_sagemaker_functions(self, model_dir: str):
94+
"""Detect and load SageMaker-style functions for backward compatibility"""
95+
# If no decorator-based code found, check for SageMaker functions
96+
if not any([
97+
self.custom_code.formatters.input_formatter,
98+
self.custom_code.formatters.output_formatter,
99+
self.custom_code.handlers.prediction_handler,
100+
self.custom_code.handlers.init_handler
101+
]):
102+
sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn')
103+
sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn')
104+
sagemaker_predict_fn = get_sagemaker_function(
105+
model_dir, 'predict_fn')
106+
sagemaker_output_fn = get_sagemaker_function(
107+
model_dir, 'output_fn')
108+
109+
if any([
110+
sagemaker_model_fn, sagemaker_input_fn,
111+
sagemaker_predict_fn, sagemaker_output_fn
112+
]):
113+
self.custom_code.is_sagemaker_script = True
114+
if sagemaker_model_fn:
115+
self.custom_code.handlers.init_handler = sagemaker_model_fn
116+
if sagemaker_input_fn:
117+
self.custom_code.formatters.input_formatter = sagemaker_input_fn
118+
if sagemaker_predict_fn:
119+
self.custom_code.handlers.prediction_handler = sagemaker_predict_fn
120+
if sagemaker_output_fn:
121+
self.custom_code.formatters.output_formatter = sagemaker_output_fn
122+
logger.info("Loaded SageMaker-style functions")
48123

49124
def apply_input_formatter(self, decoded_payload, **kwargs):
50125
"""Apply input formatter if available"""
51-
if self.input_formatter:
126+
if self.custom_code.formatters.input_formatter:
52127
try:
53-
return self.input_formatter(decoded_payload, **kwargs)
128+
return self.custom_code.formatters.input_formatter(
129+
decoded_payload, **kwargs)
54130
except Exception as e:
55131
logger.exception("Custom input formatter failed")
56132
raise CustomFormatterError(
@@ -59,9 +135,9 @@ def apply_input_formatter(self, decoded_payload, **kwargs):
59135

60136
def apply_output_formatter(self, output):
61137
"""Apply output formatter if available"""
62-
if self.output_formatter:
138+
if self.custom_code.formatters.output_formatter:
63139
try:
64-
return self.output_formatter(output)
140+
return self.custom_code.formatters.output_formatter(output)
65141
except Exception as e:
66142
logger.exception("Custom output formatter failed")
67143
raise CustomFormatterError(
@@ -79,3 +155,9 @@ async def apply_output_formatter_streaming_raw(self, stream_generator):
79155
logger.exception("Streaming formatter failed")
80156
raise CustomFormatterError(
81157
"Custom streaming formatter execution failed", e)
158+
159+
160+
def load_custom_code(model_dir: str) -> CustomCode:
161+
"""Load custom code, checking DJL decorators first, then SageMaker functions"""
162+
handler = CustomFormatterHandler()
163+
return handler.load_formatters(model_dir)

engines/python/setup/djl_python/input_parser.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,25 @@ def input_formatter(function):
3636
return function
3737

3838

39-
def predict_formatter(function):
39+
def prediction_handler(function):
4040
"""
41-
Decorator for predict_formatter. User just need to annotate @predict_formatter for their custom defined function.
41+
Decorator for prediction_handler. User just need to annotate @prediction_handler for their custom defined function.
4242
:param function: Decorator takes in the function and adds an attribute.
4343
:return:
4444
"""
4545
# adding an attribute to the function, which is used to find the decorated function.
46-
function.is_predict_formatter = True
46+
function.is_prediction_handler = True
4747
return function
4848

4949

50-
def model_loading_formatter(function):
50+
def init_handler(function):
5151
"""
52-
Decorator for model_loading_formatter. User just need to annotate @model_loading_formatter for their custom defined function.
52+
Decorator for init_handler. User just need to annotate @init_handler for their custom defined function.
5353
:param function: Decorator takes in the function and adds an attribute.
5454
:return:
5555
"""
5656
# adding an attribute to the function, which is used to find the decorated function.
57-
function.is_model_loading_formatter = True
57+
function.is_init_handler = True
5858
return function
5959

6060

engines/python/setup/djl_python/sklearn_handler.py

Lines changed: 22 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from typing import Optional
1919
from djl_python import Input, Output
2020
from djl_python.encode_decode import decode
21-
from djl_python.utils import find_model_file, get_sagemaker_function
22-
from djl_python.service_loader import get_annotated_function
21+
from djl_python.utils import find_model_file
22+
from djl_python.custom_formatter_handling import load_custom_code
2323
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio
2424

2525

@@ -28,52 +28,8 @@ class SklearnHandler:
2828
def __init__(self):
2929
self.model = None
3030
self.initialized = False
31-
self.custom_input_formatter = None
32-
self.custom_output_formatter = None
33-
self.custom_predict_formatter = None
34-
self.custom_model_loading_formatter = None
31+
self.custom_code = None
3532
self.init_properties = None
36-
self.is_sagemaker_script = False
37-
38-
def _load_custom_formatters(self, model_dir: str):
39-
"""Load custom formatters, checking DJL decorators first, then SageMaker functions."""
40-
# Check for DJL decorator-based custom formatters first
41-
self.custom_model_loading_formatter = get_annotated_function(
42-
model_dir, "is_model_loading_formatter")
43-
self.custom_input_formatter = get_annotated_function(
44-
model_dir, "is_input_formatter")
45-
self.custom_output_formatter = get_annotated_function(
46-
model_dir, "is_output_formatter")
47-
self.custom_predict_formatter = get_annotated_function(
48-
model_dir, "is_predict_formatter")
49-
50-
# If no decorator-based formatters found, check for SageMaker-style formatters
51-
if not any([
52-
self.custom_input_formatter, self.custom_output_formatter,
53-
self.custom_predict_formatter,
54-
self.custom_model_loading_formatter
55-
]):
56-
57-
sagemaker_model_fn = get_sagemaker_function(model_dir, 'model_fn')
58-
sagemaker_input_fn = get_sagemaker_function(model_dir, 'input_fn')
59-
sagemaker_predict_fn = get_sagemaker_function(
60-
model_dir, 'predict_fn')
61-
sagemaker_output_fn = get_sagemaker_function(
62-
model_dir, 'output_fn')
63-
64-
if any([
65-
sagemaker_model_fn, sagemaker_input_fn,
66-
sagemaker_predict_fn, sagemaker_output_fn
67-
]):
68-
self.is_sagemaker_script = True
69-
if sagemaker_model_fn:
70-
self.custom_model_loading_formatter = sagemaker_model_fn
71-
if sagemaker_input_fn:
72-
self.custom_input_formatter = sagemaker_input_fn
73-
if sagemaker_predict_fn:
74-
self.custom_predict_formatter = sagemaker_predict_fn
75-
if sagemaker_output_fn:
76-
self.custom_output_formatter = sagemaker_output_fn
7733

7834
def _get_trusted_types(self, properties: dict):
7935
trusted_types_str = properties.get("skops_trusted_types", "")
@@ -107,12 +63,12 @@ def initialize(self, properties: dict):
10763
f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
10864
)
10965

110-
# Load custom formatters
111-
self._load_custom_formatters(model_dir)
66+
# Load custom code
67+
self.custom_code = load_custom_code(model_dir)
11268

11369
# Load model
114-
if self.custom_model_loading_formatter:
115-
self.model = self.custom_model_loading_formatter(model_dir)
70+
if self.custom_code.handlers.init_handler:
71+
self.model = self.custom_code.handlers.init_handler(model_dir)
11672
else:
11773
model_file = find_model_file(model_dir, extensions)
11874
if not model_file:
@@ -154,7 +110,7 @@ def inference(self, inputs: Input) -> Output:
154110
accept = default_accept
155111

156112
# Validate accept type (skip validation if custom output formatter is provided)
157-
if not self.custom_output_formatter:
113+
if not self.custom_code.formatters.output_formatter:
158114
supported_accept_types = ["application/json", "text/csv"]
159115
if not any(supported_type in accept
160116
for supported_type in supported_accept_types):
@@ -164,12 +120,12 @@ def inference(self, inputs: Input) -> Output:
164120

165121
# Input processing
166122
X = None
167-
if self.custom_input_formatter:
168-
if self.is_sagemaker_script:
169-
X = self.custom_input_formatter(inputs.get_as_bytes(),
170-
content_type)
123+
if self.custom_code.formatters.input_formatter:
124+
if self.custom_code.is_sagemaker_script:
125+
X = self.custom_code.formatters.input_formatter(
126+
inputs.get_as_bytes(), content_type)
171127
else:
172-
X = self.custom_input_formatter(inputs)
128+
X = self.custom_code.formatters.input_formatter(inputs)
173129
elif "text/csv" in content_type:
174130
X = decode(inputs, content_type, require_csv_headers=False)
175131
else:
@@ -185,19 +141,22 @@ def inference(self, inputs: Input) -> Output:
185141
if X.ndim == 1:
186142
X = X.reshape(1, -1)
187143

188-
if self.custom_predict_formatter:
189-
predictions = self.custom_predict_formatter(X, self.model)
144+
if self.custom_code.handlers.prediction_handler:
145+
predictions = self.custom_code.handlers.prediction_handler(
146+
X, self.model)
190147
else:
191148
predictions = self.model.predict(X)
192149

193150
# Output processing
194151
outputs = Output()
195-
if self.custom_output_formatter:
196-
if self.is_sagemaker_script:
197-
data = self.custom_output_formatter(predictions, accept)
152+
if self.custom_code.formatters.output_formatter:
153+
if self.custom_code.is_sagemaker_script:
154+
data = self.custom_code.formatters.output_formatter(
155+
predictions, accept)
198156
outputs.add_property("Content-Type", accept)
199157
else:
200-
data = self.custom_output_formatter(predictions)
158+
data = self.custom_code.formatters.output_formatter(
159+
predictions)
201160
outputs.add(data)
202161
elif "text/csv" in accept:
203162
csv_buffer = StringIO()

0 commit comments

Comments
 (0)