1919from djl_python import Input , Output
2020from djl_python .encode_decode import decode
2121from djl_python .utils import find_model_file
22- from djl_python .service_loader import get_annotated_function
22+ from djl_python .custom_formatter_handling import CustomFormatterHandler , CustomFormatterError
2323from djl_python .import_utils import joblib , cloudpickle , skops_io as sio
2424
2525
26- class SklearnHandler :
26+ class SklearnHandler ( CustomFormatterHandler ) :
2727
2828 def __init__ (self ):
29+ super ().__init__ ()
2930 self .model = None
3031 self .initialized = False
31- self .custom_input_formatter = None
32- self .custom_output_formatter = None
33- self .custom_predict_formatter = None
32+ self .init_properties = None
3433
3534 def _get_trusted_types (self , properties : dict ):
3635 trusted_types_str = properties .get ("skops_trusted_types" , "" )
@@ -46,6 +45,8 @@ def _get_trusted_types(self, properties: dict):
4645 return trusted_types
4746
4847 def initialize (self , properties : dict ):
48+ # Store initialization properties for use during inference
49+ self .init_properties = properties .copy ()
4950 model_dir = properties .get ("model_dir" )
5051 model_format = properties .get ("model_format" , "skops" )
5152
@@ -62,64 +63,74 @@ def initialize(self, properties: dict):
6263 f"Unsupported model format: { model_format } . Supported formats: skops, joblib, pickle, cloudpickle"
6364 )
6465
65- model_file = find_model_file ( model_dir , extensions )
66- if not model_file :
67- raise FileNotFoundError (
68- f"No model file found with format ' { model_format } ' in { model_dir } "
69- )
66+ # Load custom code
67+ try :
68+ self . load_formatters ( model_dir )
69+ except CustomFormatterError as e :
70+ raise
7071
71- if model_format == "skops" :
72- trusted_types = self ._get_trusted_types (properties )
73- self .model = sio .load (model_file , trusted = trusted_types )
72+ # Load model
73+ init_result = self .apply_init_handler (model_dir )
74+ if init_result is not None :
75+ self .model = init_result
7476 else :
75- if properties . get ( "trust_insecure_model_files" ,
76- "false" ). lower () != "true" :
77- raise ValueError (
78- f"option.trust_insecure_model_files must be set to 'true' to use { model_format } format (only skops is secure by default) "
77+ model_file = find_model_file ( model_dir , extensions )
78+ if not model_file :
79+ raise FileNotFoundError (
80+ f"No model file found with format ' { model_format } ' in { model_dir } "
7981 )
8082
81- if model_format == "joblib" :
82- self .model = joblib .load (model_file )
83- elif model_format == "pickle" :
84- with open (model_file , 'rb' ) as f :
85- self .model = pickle .load (f )
86- elif model_format == "cloudpickle" :
87- with open (model_file , 'rb' ) as f :
88- self .model = cloudpickle .load (f )
89-
90- self .custom_input_formatter = get_annotated_function (
91- model_dir , "is_input_formatter" )
92- self .custom_output_formatter = get_annotated_function (
93- model_dir , "is_output_formatter" )
94- self .custom_predict_formatter = get_annotated_function (
95- model_dir , "is_predict_formatter" )
83+ if model_format == "skops" :
84+ trusted_types = self ._get_trusted_types (properties )
85+ self .model = sio .load (model_file , trusted = trusted_types )
86+ else :
87+ if properties .get ("trust_insecure_model_files" ,
88+ "false" ).lower () != "true" :
89+ raise ValueError (
90+ f"option.trust_insecure_model_files must be set to 'true' to use { model_format } format (only skops is secure by default)"
91+ )
92+
93+ if model_format == "joblib" :
94+ self .model = joblib .load (model_file )
95+ elif model_format == "pickle" :
96+ with open (model_file , 'rb' ) as f :
97+ self .model = pickle .load (f )
98+ elif model_format == "cloudpickle" :
99+ with open (model_file , 'rb' ) as f :
100+ self .model = cloudpickle .load (f )
96101
97102 self .initialized = True
98103
99104 def inference (self , inputs : Input ) -> Output :
100105 content_type = inputs .get_property ("Content-Type" )
101- accept = inputs .get_property ("Accept" ) or "application/json"
106+ properties = inputs .get_properties ()
107+ default_accept = self .init_properties .get ("default_accept" ,
108+ "application/json" )
109+
110+ accept = inputs .get_property ("Accept" )
111+
112+ # If no accept type is specified in the request, use default
113+ if accept == "*/*" :
114+ accept = default_accept
102115
103116 # Validate accept type (skip validation if custom output formatter is provided)
104- if not self .custom_output_formatter :
105- supported_accept_types = ["application/json" , "text/csv" ]
117+ if self .output_formatter is None : # No formatter available
106118 if not any (supported_type in accept
107- for supported_type in supported_accept_types ):
119+ for supported_type in [ "application/json" , "text/csv" ] ):
108120 raise ValueError (
109- f"Unsupported Accept type: { accept } . Supported types: { supported_accept_types } "
121+ f"Unsupported Accept type: { accept } . Supported types: application/json, text/csv "
110122 )
111123
112124 # Input processing
113- X = None
114- if self .custom_input_formatter :
115- X = self .custom_input_formatter (inputs )
116- elif "text/csv" in content_type :
117- X = decode (inputs , content_type , require_csv_headers = False )
118- else :
119- input_map = decode (inputs , content_type )
120- data = input_map .get ("inputs" ) if isinstance (input_map ,
121- dict ) else input_map
122- X = np .array (data )
125+ X = self .apply_input_formatter (inputs , content_type = content_type )
126+ if X is inputs : # No formatter applied
127+ if "text/csv" in content_type :
128+ X = decode (inputs , content_type , require_csv_headers = False )
129+ else :
130+ input_map = decode (inputs , content_type )
131+ data = input_map .get ("inputs" ) if isinstance (
132+ input_map , dict ) else input_map
133+ X = np .array (data )
123134
124135 if X is None or not hasattr (X , 'ndim' ):
125136 raise ValueError (
@@ -128,18 +139,19 @@ def inference(self, inputs: Input) -> Output:
128139 if X .ndim == 1 :
129140 X = X .reshape (1 , - 1 )
130141
131- if self .custom_predict_formatter :
132- predictions = self .custom_predict_formatter (self .model , X )
133- else :
142+ predictions = self .apply_prediction_handler (X , self .model )
143+ if predictions is None :
134144 predictions = self .model .predict (X )
135145
136146 # Output processing
137- if self .custom_output_formatter :
138- return self .custom_output_formatter (predictions )
139-
140- # Supports CSV/JSON outputs by default
141147 outputs = Output ()
142- if "text/csv" in accept :
148+ formatted_output = self .apply_output_formatter (predictions ,
149+ accept = accept )
150+ if formatted_output is not predictions : # Formatter was applied
151+ if self .is_sagemaker_script :
152+ outputs .add_property ("Content-Type" , accept )
153+ outputs .add (formatted_output )
154+ elif "text/csv" in accept :
143155 csv_buffer = StringIO ()
144156 np .savetxt (csv_buffer , predictions , fmt = '%s' , delimiter = ',' )
145157 outputs .add (csv_buffer .getvalue ().rstrip ())
0 commit comments