1111# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
1212# the specific language governing permissions and limitations under the License.
1313import logging
14+ import os
15+ from dataclasses import dataclass
16+ from typing import Optional , Callable
1417
1518from djl_python .service_loader import get_annotated_function
19+ from djl_python .utils import get_sagemaker_function
1620
1721logger = logging .getLogger (__name__ )
1822
@@ -26,31 +30,103 @@ def __init__(self, message: str, original_exception: Exception):
2630 self .__cause__ = original_exception
2731
2832
33+ @dataclass
34+ class CustomFormatters :
35+ """Container for input/output formatting functions"""
36+ input_formatter : Optional [Callable ] = None
37+ output_formatter : Optional [Callable ] = None
38+
39+
40+ @dataclass
41+ class CustomHandlers :
42+ """Container for prediction/initialization handler functions"""
43+ prediction_handler : Optional [Callable ] = None
44+ init_handler : Optional [Callable ] = None
45+
46+
47+ @dataclass
48+ class CustomCode :
49+ """Container for all custom formatters and handlers"""
50+ formatters : CustomFormatters
51+ handlers : CustomHandlers
52+ is_sagemaker_script : bool = False
53+
54+ def __init__ (self ):
55+ self .formatters = CustomFormatters ()
56+ self .handlers = CustomHandlers ()
57+ self .is_sagemaker_script = False
58+
59+
2960class CustomFormatterHandler :
3061
3162 def __init__ (self ):
32- self .output_formatter = None
33- self .input_formatter = None
63+ self .custom_code = CustomCode ()
3464
35- def load_formatters (self , model_dir : str ):
36- """Load custom formatters from model.py"""
65+ def load_formatters (self , model_dir : str ) -> CustomCode :
66+ """Load custom formatters/handlers from model.py with SageMaker detection """
3767 try :
38- self .input_formatter = get_annotated_function (
68+ self .custom_code . formatters . input_formatter = get_annotated_function (
3969 model_dir , "is_input_formatter" )
40- self .output_formatter = get_annotated_function (
70+ self .custom_code . formatters . output_formatter = get_annotated_function (
4171 model_dir , "is_output_formatter" )
72+ self .custom_code .handlers .prediction_handler = get_annotated_function (
73+ model_dir , "is_prediction_handler" )
74+ self .custom_code .handlers .init_handler = get_annotated_function (
75+ model_dir , "is_init_handler" )
76+
77+ # Detect SageMaker script pattern for backward compatibility
78+ self ._detect_sagemaker_functions (model_dir )
79+
4280 logger .info (
43- f"Loaded formatters - input: { self .input_formatter } , output: { self .output_formatter } "
81+ f"Loaded formatters - input: { bool (self .custom_code .formatters .input_formatter )} , "
82+ f"output: { bool (self .custom_code .formatters .output_formatter )} "
4483 )
84+ logger .info (
85+ f"Loaded handlers - prediction: { bool (self .custom_code .handlers .prediction_handler )} , "
86+ f"init: { bool (self .custom_code .handlers .init_handler )} , "
87+ f"sagemaker: { self .custom_code .is_sagemaker_script } " )
88+ return self .custom_code
4589 except Exception as e :
4690 raise CustomFormatterError (
47- f"Failed to load custom formatters from { model_dir } " , e )
91+ f"Failed to load custom code from { model_dir } " , e )
92+
93+ def _detect_sagemaker_functions (self , model_dir : str ):
94+ """Detect and load SageMaker-style functions for backward compatibility"""
95+ # If no decorator-based code found, check for SageMaker functions
96+ if not any ([
97+ self .custom_code .formatters .input_formatter ,
98+ self .custom_code .formatters .output_formatter ,
99+ self .custom_code .handlers .prediction_handler ,
100+ self .custom_code .handlers .init_handler
101+ ]):
102+ sagemaker_model_fn = get_sagemaker_function (model_dir , 'model_fn' )
103+ sagemaker_input_fn = get_sagemaker_function (model_dir , 'input_fn' )
104+ sagemaker_predict_fn = get_sagemaker_function (
105+ model_dir , 'predict_fn' )
106+ sagemaker_output_fn = get_sagemaker_function (
107+ model_dir , 'output_fn' )
108+
109+ if any ([
110+ sagemaker_model_fn , sagemaker_input_fn ,
111+ sagemaker_predict_fn , sagemaker_output_fn
112+ ]):
113+ self .custom_code .is_sagemaker_script = True
114+ if sagemaker_model_fn :
115+ self .custom_code .handlers .init_handler = sagemaker_model_fn
116+ if sagemaker_input_fn :
117+ self .custom_code .formatters .input_formatter = sagemaker_input_fn
118+ if sagemaker_predict_fn :
119+ self .custom_code .handlers .prediction_handler = sagemaker_predict_fn
120+ if sagemaker_output_fn :
121+ self .custom_code .formatters .output_formatter = sagemaker_output_fn
122+ logger .info ("Loaded SageMaker-style functions" )
48123
49124 def apply_input_formatter (self , decoded_payload , ** kwargs ):
50125 """Apply input formatter if available"""
51- if self .input_formatter :
126+ if self .custom_code . formatters . input_formatter :
52127 try :
53- return self .input_formatter (decoded_payload , ** kwargs )
128+ return self .custom_code .formatters .input_formatter (
129+ decoded_payload , ** kwargs )
54130 except Exception as e :
55131 logger .exception ("Custom input formatter failed" )
56132 raise CustomFormatterError (
@@ -59,9 +135,9 @@ def apply_input_formatter(self, decoded_payload, **kwargs):
59135
60136 def apply_output_formatter (self , output ):
61137 """Apply output formatter if available"""
62- if self .output_formatter :
138+ if self .custom_code . formatters . output_formatter :
63139 try :
64- return self .output_formatter (output )
140+ return self .custom_code . formatters . output_formatter (output )
65141 except Exception as e :
66142 logger .exception ("Custom output formatter failed" )
67143 raise CustomFormatterError (
@@ -79,3 +155,9 @@ async def apply_output_formatter_streaming_raw(self, stream_generator):
79155 logger .exception ("Streaming formatter failed" )
80156 raise CustomFormatterError (
81157 "Custom streaming formatter execution failed" , e )
158+
159+
160+ def load_custom_code (model_dir : str ) -> CustomCode :
161+ """Load custom code, checking DJL decorators first, then SageMaker functions"""
162+ handler = CustomFormatterHandler ()
163+ return handler .load_formatters (model_dir )
0 commit comments