1+ import functools
2+ import inspect
3+ import logging
4+ from io import BytesIO
5+ from typing import Any , Callable , List , Optional , Tuple
6+ import requests
7+
8+ from .utils import (
9+ UrlType ,
10+ is_url ,
11+ generate_object_name ,
12+ detect_content_type_from_bytes ,
13+ guess_extension_from_content_type ,
14+ parse_s3_url
15+ )
16+
# Module-level logger used by the load/save helpers in this file; it is
# registered under the fixed name "multi_modal".
17+ logger = logging .getLogger ("multi_modal" )
18+
19+
class LoadSaveObjectManager:
    """
    Provide load/save decorators that operate on a specific storage client.

    The manager can be instantiated with a storage client and exposes decorator
    factories for `load_object` and `save_object`. A default module-level manager
    is also provided for backwards compatibility with existing helper functions.

    The storage client is duck-typed: `get_file_stream(object_name, bucket)` is
    required for S3 downloads and `upload_fileobj(file_obj, object_name, bucket)`
    for uploads. Both are expected to return a `(success, payload)` pair.
    """

    def __init__(self, storage_client: Any):
        # The client may be None; it is only validated when first needed.
        self._storage_client = storage_client

    def _get_client(self) -> Any:
        """
        Return the configured storage client.

        Raises:
            ValueError: If the manager was created without a storage client.
        """
        if self._storage_client is None:
            raise ValueError("Storage client is not initialized.")
        return self._storage_client

    @staticmethod
    def _callable_name(func: Callable) -> str:
        """
        Return a printable name for *func*.

        Not every callable has `__name__` (e.g. `functools.partial` objects or
        instances defining `__call__`), so fall back to `repr` instead of
        letting an AttributeError escape from a log or error message.
        """
        return getattr(func, "__name__", None) or repr(func)

    def download_file_from_url(
        self,
        url: str,
        url_type: "UrlType",
        timeout: int = 30
    ) -> Optional[bytes]:
        """
        Download file content from S3 URL or HTTP/HTTPS URL as bytes.

        Args:
            url: Location of the file. An empty value short-circuits to None.
            url_type: Kind of URL; "http", "https" and "s3" are supported.
            timeout: Timeout in seconds passed to the HTTP request.

        Returns:
            The downloaded bytes, or None when `url` is empty or when the
            download fails for any reason (the failure is logged).

        Raises:
            ValueError: If `url_type` is missing. This is the only error that
                propagates; everything else is logged and reported as None.
        """
        if not url:
            return None

        if not url_type:
            raise ValueError("url_type must be provided for download_file_from_url")

        try:
            if url_type in ("http", "https"):
                response = requests.get(url, timeout=timeout)
                response.raise_for_status()
                return response.content

            if url_type == "s3":
                client = self._get_client()
                bucket, object_name = parse_s3_url(url)

                if not hasattr(client, 'get_file_stream'):
                    raise ValueError("Storage client does not have get_file_stream method")

                success, stream = client.get_file_stream(object_name, bucket)
                if not success:
                    # On failure the second element carries the error detail.
                    raise ValueError(f"Failed to get file stream from storage: {stream}")

                try:
                    return stream.read()
                except Exception as exc:
                    raise ValueError(f"Failed to read stream content: {exc}") from exc
                finally:
                    # Always release the stream, including when read() fails.
                    if hasattr(stream, 'close'):
                        stream.close()

            raise ValueError(f"Unsupported URL type: {url_type}")

        except Exception as exc:
            # Deliberate best-effort: callers treat None as "download failed".
            logger.error("Failed to download file from URL: %s", exc)
            return None

    def _upload_bytes_to_minio(
        self,
        bytes_data: bytes,
        object_name: Optional[str] = None,
        bucket: str = "multi-modal",
        content_type: str = "application/octet-stream",
    ) -> str:
        """
        Upload bytes to MinIO and return the resulting file URL.

        Args:
            bytes_data: Payload to upload.
            object_name: Target object name. When None, a name is generated
                with an extension guessed from `content_type`.
            bucket: Destination bucket.
            content_type: Content type, used only to pick the extension of a
                generated object name.

        Returns:
            Whatever the storage client reports as the upload result.

        Raises:
            ValueError: If no client is configured, the client lacks
                `upload_fileobj`, or the client reports a failed upload.
        """
        client = self._get_client()

        if not hasattr(client, 'upload_fileobj'):
            raise ValueError("Storage client must have upload_fileobj method")

        if object_name is None:
            file_ext = guess_extension_from_content_type(content_type)
            object_name = generate_object_name(file_ext)

        success, result = client.upload_fileobj(BytesIO(bytes_data), object_name, bucket)

        if not success:
            raise ValueError(f"Failed to upload file to MinIO: {result}")

        return result

    def load_object(
        self,
        input_names: List[str],
        input_data_transformer: Optional[List[Callable[[bytes], Any]]] = None,
    ):
        """
        Decorator factory that downloads inputs before invoking the wrapped callable.

        Each argument named in `input_names` must be a URL string (or a
        list/tuple of them, nested arbitrarily). It is replaced by the
        downloaded bytes, or by the output of the transformer at the same index
        in `input_data_transformer` when one exists. Arguments that are None
        or absent from the call are left alone.

        Both plain and `async def` callables are supported; a coroutine function
        stays a coroutine function after decoration, so it composes with
        `save_object`. Downloads are performed synchronously in either case.

        Raises (at call time):
            ValueError: If a targeted value is not a recognised URL string or
                its download fails.
        """

        def _transform_single_value(param_name: str, value: Any,
                                    transformer: Optional[Callable[[bytes], Any]]) -> Any:
            # Download one URL and optionally convert the bytes.
            if isinstance(value, str):
                url_type = is_url(value)
                if url_type:
                    bytes_data = self.download_file_from_url(value, url_type=url_type)

                    if bytes_data is None:
                        raise ValueError(f"Failed to download file from URL: {value}")

                    if transformer:
                        transformed_data = transformer(bytes_data)
                        logger.info(
                            "Downloaded %s from URL and transformed using %s",
                            param_name,
                            self._callable_name(transformer),
                        )
                        return transformed_data

                    logger.info("Downloaded %s from URL as bytes (binary stream)", param_name)
                    return bytes_data

            raise ValueError(
                f"Parameter '{param_name}' is not a URL string. "
                f"load_object decorator expects S3 or HTTP/HTTPS URLs. "
                f"Got: {type(value).__name__}"
            )

        def _process_value(param_name: str, value: Any,
                           transformer: Optional[Callable[[bytes], Any]]) -> Any:
            # Recurse through lists/tuples, preserving the container type.
            if value is None:
                return None

            if isinstance(value, (list, tuple)):
                processed_items = [
                    _process_value(param_name, item, transformer)
                    for item in value
                ]
                return type(value)(processed_items)

            return _transform_single_value(param_name, value, transformer)

        def decorator(func: Callable):
            # The signature never changes, so resolve it once at decoration time.
            sig = inspect.signature(func)

            def _prepare_call(args: tuple, kwargs: dict):
                # Bind the call, swap URL arguments for their loaded content,
                # and return the positional/keyword arguments to forward.
                bound_args = sig.bind(*args, **kwargs)
                bound_args.apply_defaults()

                for i, param_name in enumerate(input_names):
                    if param_name not in bound_args.arguments:
                        continue

                    original_data = bound_args.arguments[param_name]
                    if original_data is None:
                        continue

                    # Transformers are matched to input names by position.
                    transformer_func = (
                        input_data_transformer[i]
                        if input_data_transformer and i < len(input_data_transformer)
                        else None
                    )

                    bound_args.arguments[param_name] = _process_value(
                        param_name, original_data, transformer_func
                    )

                return bound_args.args, bound_args.kwargs

            if inspect.iscoroutinefunction(func):
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    call_args, call_kwargs = _prepare_call(args, kwargs)
                    return await func(*call_args, **call_kwargs)

                return async_wrapper

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                call_args, call_kwargs = _prepare_call(args, kwargs)
                return func(*call_args, **call_kwargs)

            return wrapper

        return decorator

    def save_object(
        self,
        output_names: List[str],
        output_transformers: Optional[List[Callable[[Any], bytes]]] = None,
        bucket: str = "multi-modal",
    ):
        """
        Decorator factory that uploads outputs to storage after function execution.

        The wrapped callable's return value is treated as a tuple of outputs
        (a non-tuple return counts as a single output) and must contain exactly
        one entry per name in `output_names`. Every entry is uploaded, or each
        element of it when the entry is a list/tuple, and replaced by the
        resulting URL string; None entries are passed through unchanged.

        Args:
            output_names: Names of the outputs, used for validation and logging.
            output_transformers: Optional converters to bytes, matched to
                outputs by position. Without one the output must already be bytes.
            bucket: Destination bucket for every upload.

        Returns:
            A decorator. The decorated callable returns a single value when
            there is one output, otherwise a tuple. Works for sync and async
            callables.

        Raises (at call time):
            ValueError: On an output-count mismatch, a transformer that does
                not return bytes, a non-bytes output without a transformer,
                or a failed upload.
        """

        def decorator(func: Callable) -> Callable:
            def _handle_results(results: Any):
                if not isinstance(results, tuple):
                    results_tuple = (results,)
                else:
                    results_tuple = results

                if len(results_tuple) != len(output_names):
                    raise ValueError(
                        f"Function returned {len(results_tuple)} values, "
                        f"but expected {len(output_names)} outputs"
                    )

                def _upload_single_output(
                    name: str,
                    value: Any,
                    transformer: Optional[Callable[[Any], bytes]]
                ) -> str:
                    # Convert one output to bytes, upload it, return its URL.
                    if transformer:
                        transformer_name = self._callable_name(transformer)
                        bytes_data = transformer(value)
                        if not isinstance(bytes_data, bytes):
                            raise ValueError(
                                f"Transformer {transformer_name} for {name} must return bytes, "
                                f"got {type(bytes_data).__name__}"
                            )
                        logger.info("Transformed %s using %s to bytes", name, transformer_name)
                    else:
                        if not isinstance(value, bytes):
                            raise ValueError(
                                f"Return value for {name} must be bytes when no transformer is provided, "
                                f"got {type(value).__name__}"
                            )
                        bytes_data = value
                        logger.info("Using %s as bytes directly", name)

                    content_type = detect_content_type_from_bytes(bytes_data)
                    logger.info("Detected content type for %s: %s", name, content_type)

                    file_url = self._upload_bytes_to_minio(
                        bytes_data,
                        object_name=None,
                        content_type=content_type,
                        bucket=bucket,
                    )
                    logger.info("Uploaded %s to MinIO: %s", name, file_url)
                    # NOTE(review): presumably the client returns a path that
                    # starts with "/", so this yields an "s3://..." URL - confirm
                    # against the storage client.
                    return "s3:/" + file_url

                def _process_output_value(
                    name: str,
                    value: Any,
                    transformer: Optional[Callable[[Any], bytes]]
                ) -> Any:
                    # Recurse through lists/tuples, preserving the container type.
                    if value is None:
                        return None

                    if isinstance(value, (list, tuple)):
                        processed_items = [
                            _process_output_value(name, item, transformer)
                            for item in value
                        ]
                        return type(value)(processed_items)

                    return _upload_single_output(name, value, transformer)

                uploaded_urls = []
                for i, (result, name) in enumerate(zip(results_tuple, output_names)):
                    # Transformers are matched to outputs by position.
                    transformer_func = (
                        output_transformers[i]
                        if output_transformers and i < len(output_transformers)
                        else None
                    )
                    uploaded_urls.append(_process_output_value(name, result, transformer_func))

                if len(uploaded_urls) == 1:
                    return uploaded_urls[0]
                return tuple(uploaded_urls)

            if inspect.iscoroutinefunction(func):
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    results = await func(*args, **kwargs)
                    return _handle_results(results)

                return async_wrapper

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                results = func(*args, **kwargs)
                return _handle_results(results)

            return wrapper

        return decorator
0 commit comments