1818from typing import Optional
1919from djl_python import Input , Output
2020from djl_python .encode_decode import decode
21- from djl_python .utils import find_model_file
21+ from djl_python .utils import find_model_file , get_sagemaker_function
2222from djl_python .service_loader import get_annotated_function
2323from djl_python .import_utils import joblib , cloudpickle , skops_io as sio
2424
@@ -31,6 +31,49 @@ def __init__(self):
3131 self .custom_input_formatter = None
3232 self .custom_output_formatter = None
3333 self .custom_predict_formatter = None
34+ self .custom_model_loading_formatter = None
35+ self .init_properties = None
36+ self .is_sagemaker_script = False
37+
38+ def _load_custom_formatters (self , model_dir : str ):
39+ """Load custom formatters, checking DJL decorators first, then SageMaker functions."""
40+ # Check for DJL decorator-based custom formatters first
41+ self .custom_model_loading_formatter = get_annotated_function (
42+ model_dir , "is_model_loading_formatter" )
43+ self .custom_input_formatter = get_annotated_function (
44+ model_dir , "is_input_formatter" )
45+ self .custom_output_formatter = get_annotated_function (
46+ model_dir , "is_output_formatter" )
47+ self .custom_predict_formatter = get_annotated_function (
48+ model_dir , "is_predict_formatter" )
49+
50+ # If no decorator-based formatters found, check for SageMaker-style formatters
51+ if not any ([
52+ self .custom_input_formatter , self .custom_output_formatter ,
53+ self .custom_predict_formatter ,
54+ self .custom_model_loading_formatter
55+ ]):
56+
57+ sagemaker_model_fn = get_sagemaker_function (model_dir , 'model_fn' )
58+ sagemaker_input_fn = get_sagemaker_function (model_dir , 'input_fn' )
59+ sagemaker_predict_fn = get_sagemaker_function (
60+ model_dir , 'predict_fn' )
61+ sagemaker_output_fn = get_sagemaker_function (
62+ model_dir , 'output_fn' )
63+
64+ if any ([
65+ sagemaker_model_fn , sagemaker_input_fn ,
66+ sagemaker_predict_fn , sagemaker_output_fn
67+ ]):
68+ self .is_sagemaker_script = True
69+ if sagemaker_model_fn :
70+ self .custom_model_loading_formatter = sagemaker_model_fn
71+ if sagemaker_input_fn :
72+ self .custom_input_formatter = sagemaker_input_fn
73+ if sagemaker_predict_fn :
74+ self .custom_predict_formatter = sagemaker_predict_fn
75+ if sagemaker_output_fn :
76+ self .custom_output_formatter = sagemaker_output_fn
3477
3578 def _get_trusted_types (self , properties : dict ):
3679 trusted_types_str = properties .get ("skops_trusted_types" , "" )
@@ -46,6 +89,8 @@ def _get_trusted_types(self, properties: dict):
4689 return trusted_types
4790
4891 def initialize (self , properties : dict ):
92+ # Store initialization properties for use during inference
93+ self .init_properties = properties .copy ()
4994 model_dir = properties .get ("model_dir" )
5095 model_format = properties .get ("model_format" , "skops" )
5196
@@ -62,43 +107,51 @@ def initialize(self, properties: dict):
62107 f"Unsupported model format: { model_format } . Supported formats: skops, joblib, pickle, cloudpickle"
63108 )
64109
65- model_file = find_model_file (model_dir , extensions )
66- if not model_file :
67- raise FileNotFoundError (
68- f"No model file found with format '{ model_format } ' in { model_dir } "
69- )
110+ # Load custom formatters
111+ self ._load_custom_formatters (model_dir )
70112
71- if model_format == "skops" :
72- trusted_types = self ._get_trusted_types ( properties )
73- self .model = sio . load ( model_file , trusted = trusted_types )
113+ # Load model
114+ if self .custom_model_loading_formatter :
115+ self .model = self . custom_model_loading_formatter ( model_dir )
74116 else :
75- if properties . get ( "trust_insecure_model_files" ,
76- "false" ). lower () != "true" :
77- raise ValueError (
78- f"option.trust_insecure_model_files must be set to 'true' to use { model_format } format (only skops is secure by default) "
117+ model_file = find_model_file ( model_dir , extensions )
118+ if not model_file :
119+ raise FileNotFoundError (
120+ f"No model file found with format ' { model_format } ' in { model_dir } "
79121 )
80122
81- if model_format == "joblib" :
82- self .model = joblib .load (model_file )
83- elif model_format == "pickle" :
84- with open (model_file , 'rb' ) as f :
85- self .model = pickle .load (f )
86- elif model_format == "cloudpickle" :
87- with open (model_file , 'rb' ) as f :
88- self .model = cloudpickle .load (f )
89-
90- self .custom_input_formatter = get_annotated_function (
91- model_dir , "is_input_formatter" )
92- self .custom_output_formatter = get_annotated_function (
93- model_dir , "is_output_formatter" )
94- self .custom_predict_formatter = get_annotated_function (
95- model_dir , "is_predict_formatter" )
123+ if model_format == "skops" :
124+ trusted_types = self ._get_trusted_types (properties )
125+ self .model = sio .load (model_file , trusted = trusted_types )
126+ else :
127+ if properties .get ("trust_insecure_model_files" ,
128+ "false" ).lower () != "true" :
129+ raise ValueError (
130+ f"option.trust_insecure_model_files must be set to 'true' to use { model_format } format (only skops is secure by default)"
131+ )
132+
133+ if model_format == "joblib" :
134+ self .model = joblib .load (model_file )
135+ elif model_format == "pickle" :
136+ with open (model_file , 'rb' ) as f :
137+ self .model = pickle .load (f )
138+ elif model_format == "cloudpickle" :
139+ with open (model_file , 'rb' ) as f :
140+ self .model = cloudpickle .load (f )
96141
97142 self .initialized = True
98143
99144 def inference (self , inputs : Input ) -> Output :
100145 content_type = inputs .get_property ("Content-Type" )
101- accept = inputs .get_property ("Accept" ) or "application/json"
146+ properties = inputs .get_properties ()
147+ default_accept = self .init_properties .get ("default_accept" ,
148+ "application/json" )
149+
150+ accept = inputs .get_property ("Accept" )
151+
152+ # If no accept type is specified in the request, use default
153+ if accept == "*/*" :
154+ accept = default_accept
102155
103156 # Validate accept type (skip validation if custom output formatter is provided)
104157 if not self .custom_output_formatter :
@@ -112,7 +165,11 @@ def inference(self, inputs: Input) -> Output:
112165 # Input processing
113166 X = None
114167 if self .custom_input_formatter :
115- X = self .custom_input_formatter (inputs )
168+ if self .is_sagemaker_script :
169+ X = self .custom_input_formatter (inputs .get_as_bytes (),
170+ content_type )
171+ else :
172+ X = self .custom_input_formatter (inputs )
116173 elif "text/csv" in content_type :
117174 X = decode (inputs , content_type , require_csv_headers = False )
118175 else :
@@ -129,17 +186,20 @@ def inference(self, inputs: Input) -> Output:
129186 X = X .reshape (1 , - 1 )
130187
131188 if self .custom_predict_formatter :
132- predictions = self .custom_predict_formatter (self .model , X )
189+ predictions = self .custom_predict_formatter (X , self .model )
133190 else :
134191 predictions = self .model .predict (X )
135192
136193 # Output processing
137- if self .custom_output_formatter :
138- return self .custom_output_formatter (predictions )
139-
140- # Supports CSV/JSON outputs by default
141194 outputs = Output ()
142- if "text/csv" in accept :
195+ if self .custom_output_formatter :
196+ if self .is_sagemaker_script :
197+ data = self .custom_output_formatter (predictions , accept )
198+ outputs .add_property ("Content-Type" , accept )
199+ else :
200+ data = self .custom_output_formatter (predictions )
201+ outputs .add (data )
202+ elif "text/csv" in accept :
143203 csv_buffer = StringIO ()
144204 np .savetxt (csv_buffer , predictions , fmt = '%s' , delimiter = ',' )
145205 outputs .add (csv_buffer .getvalue ().rstrip ())
0 commit comments