1+ import functools
2+ import inspect
3+ import logging
4+ from io import BytesIO
5+ from typing import Any , Callable , List , Optional , Tuple
6+ import requests
7+
8+ from .utils import (
9+ is_url ,
10+ generate_object_name ,
11+ detect_content_type_from_bytes ,
12+ guess_extension_from_content_type ,
13+ parse_s3_url
14+ )
15+
16+ logger = logging .getLogger ("multi_modal" )
17+
18+
19+ class LoadSaveObjectManager :
20+ """
21+ Provide load/save decorators that operate on a specific storage client.
22+
23+ The manager can be instantiated with a storage client and exposes decorator
24+ factories for `load_object` and `save_object`. A default module-level manager
25+ is also provided for backwards compatibility with existing helper functions.
26+ """
27+
28+ def __init__ (self , storage_client : Any ):
29+ self ._storage_client = storage_client
30+
31+ def _get_client (self ) -> Any :
32+ """
33+ Return a ready-to-use storage client, ensuring initialization first.
34+ """
35+ if self ._storage_client is None :
36+ raise ValueError ("Storage client is not initialized." )
37+ return self ._storage_client
38+
39+ def download_file_from_url (self , url : str , timeout : int = 30 ) -> Optional [bytes ]:
40+ """
41+ Download file content from S3 URL or HTTP/HTTPS URL as bytes.
42+ """
43+ if not url :
44+ return None
45+
46+ try :
47+ if url .startswith (('http://' , 'https://' )):
48+ response = requests .get (url , timeout = timeout )
49+ response .raise_for_status ()
50+ return response .content
51+
52+ if url .startswith ('s3://' ) or url .startswith ('/' ):
53+ client = self ._get_client ()
54+ bucket , object_name = parse_s3_url (url )
55+
56+ if not hasattr (client , 'get_file_stream' ):
57+ raise ValueError ("Storage client does not have get_file_stream method" )
58+
59+ success , stream = client .get_file_stream (object_name , bucket )
60+ if not success :
61+ raise ValueError (f"Failed to get file stream from storage: { stream } " )
62+
63+ try :
64+ bytes_data = stream .read ()
65+ if hasattr (stream , 'close' ):
66+ stream .close ()
67+ return bytes_data
68+ except Exception as exc :
69+ raise ValueError (f"Failed to read stream content: { exc } " ) from exc
70+
71+ raise ValueError (f"Unsupported URL format: { url [:50 ]} ..." )
72+
73+ except Exception as exc :
74+ logger .error (f"Failed to download file from URL: { exc } " )
75+ return None
76+
77+ def _upload_bytes_to_minio (
78+ self ,
79+ bytes_data : bytes ,
80+ object_name : Optional [str ] = None ,
81+ bucket : str = "multi-modal" ,
82+ content_type : str = "application/octet-stream" ,
83+ ) -> str :
84+ """
85+ Upload bytes to MinIO and return the resulting file URL.
86+ """
87+ client = self ._get_client ()
88+
89+ if not hasattr (client , 'upload_fileobj' ):
90+ raise ValueError ("Storage client must have upload_fileobj method" )
91+
92+ if object_name is None :
93+ file_ext = guess_extension_from_content_type (content_type )
94+ object_name = generate_object_name (file_ext )
95+
96+ file_obj = BytesIO (bytes_data )
97+ success , result = client .upload_fileobj (file_obj , object_name , bucket )
98+
99+ if not success :
100+ raise ValueError (f"Failed to upload file to MinIO: { result } " )
101+
102+ return result
103+
104+ def load_object (
105+ self ,
106+ input_names : List [str ],
107+ input_data_transformer : Optional [List [Callable [[bytes ], Any ]]] = None ,
108+ ):
109+ """
110+ Decorator factory that downloads inputs before invoking the wrapped callable.
111+ """
112+
113+ def decorator (func : Callable ):
114+ @functools .wraps (func )
115+ def wrapper (* args , ** kwargs ):
116+ def _transform_single_value (param_name : str , value : Any ,
117+ transformer : Optional [Callable [[bytes ], Any ]]) -> Any :
118+ if isinstance (value , str ) and is_url (value ):
119+ bytes_data = self .download_file_from_url (value )
120+
121+ if bytes_data is None :
122+ raise ValueError (f"Failed to download file from URL: { value } " )
123+
124+ if transformer :
125+ transformed_data = transformer (bytes_data )
126+ logger .info (
127+ f"Downloaded { param_name } from URL and transformed "
128+ f"using { transformer .__name__ } "
129+ )
130+ return transformed_data
131+
132+ logger .info (f"Downloaded { param_name } from URL as bytes (binary stream)" )
133+ return bytes_data
134+
135+ raise ValueError (
136+ f"Parameter '{ param_name } ' is not a URL string. "
137+ f"load_object decorator expects S3 or HTTP/HTTPS URLs. "
138+ f"Got: { type (value ).__name__ } "
139+ )
140+
141+ def _process_value (param_name : str , value : Any ,
142+ transformer : Optional [Callable [[bytes ], Any ]]) -> Any :
143+ if value is None :
144+ return None
145+
146+ if isinstance (value , (list , tuple )):
147+ processed_items = [
148+ _process_value (param_name , item , transformer )
149+ for item in value
150+ ]
151+ return type (value )(processed_items )
152+
153+ return _transform_single_value (param_name , value , transformer )
154+
155+ sig = inspect .signature (func )
156+ bound_args = sig .bind (* args , ** kwargs )
157+ bound_args .apply_defaults ()
158+
159+ for i , param_name in enumerate (input_names ):
160+ if param_name not in bound_args .arguments :
161+ continue
162+
163+ original_data = bound_args .arguments [param_name ]
164+ if original_data is None :
165+ continue
166+
167+ transformer_func = (
168+ input_data_transformer [i ]
169+ if input_data_transformer and i < len (input_data_transformer )
170+ else None
171+ )
172+
173+ transformed_data = _process_value (param_name , original_data , transformer_func )
174+ bound_args .arguments [param_name ] = transformed_data
175+
176+ return func (* bound_args .args , ** bound_args .kwargs )
177+
178+ return wrapper
179+
180+ return decorator
181+
182+ def save_object (
183+ self ,
184+ output_names : List [str ],
185+ output_transformers : Optional [List [Callable [[Any ], bytes ]]] = None ,
186+ bucket : str = "multi-modal" ,
187+ ):
188+ """
189+ Decorator factory that uploads outputs to storage after function execution.
190+ """
191+
192+ def decorator (func : Callable ) -> Callable :
193+ def _handle_results (results : Any ):
194+ if not isinstance (results , tuple ):
195+ results_tuple = (results ,)
196+ else :
197+ results_tuple = results
198+
199+ if len (results_tuple ) != len (output_names ):
200+ raise ValueError (
201+ f"Function returned { len (results_tuple )} values, "
202+ f"but expected { len (output_names )} outputs"
203+ )
204+
205+ def _upload_single_output (
206+ name : str ,
207+ value : Any ,
208+ transformer : Optional [Callable [[Any ], bytes ]]
209+ ) -> str :
210+ if transformer :
211+ bytes_data = transformer (value )
212+ if not isinstance (bytes_data , bytes ):
213+ raise ValueError (
214+ f"Transformer { transformer .__name__ } for { name } must return bytes, "
215+ f"got { type (bytes_data ).__name__ } "
216+ )
217+ logger .info (f"Transformed { name } using { transformer .__name__ } to bytes" )
218+ else :
219+ if not isinstance (value , bytes ):
220+ raise ValueError (
221+ f"Return value for { name } must be bytes when no transformer is provided, "
222+ f"got { type (value ).__name__ } "
223+ )
224+ bytes_data = value
225+ logger .info (f"Using { name } as bytes directly" )
226+
227+ content_type = detect_content_type_from_bytes (bytes_data )
228+ logger .info (f"Detected content type for { name } : { content_type } " )
229+
230+ file_url = self ._upload_bytes_to_minio (
231+ bytes_data ,
232+ object_name = None ,
233+ content_type = content_type ,
234+ bucket = bucket ,
235+ )
236+ logger .info (f"Uploaded { name } to MinIO: { file_url } " )
237+ return "s3:/" + file_url
238+
239+ def _process_output_value (
240+ name : str ,
241+ value : Any ,
242+ transformer : Optional [Callable [[Any ], bytes ]]
243+ ) -> Any :
244+ if value is None :
245+ return None
246+
247+ if isinstance (value , (list , tuple )):
248+ processed_items = [
249+ _process_output_value (name , item , transformer )
250+ for item in value
251+ ]
252+ return type (value )(processed_items )
253+
254+ return _upload_single_output (name , value , transformer )
255+
256+ uploaded_urls = []
257+ for i , (result , name ) in enumerate (zip (results_tuple , output_names )):
258+ transformer_func = (
259+ output_transformers [i ]
260+ if output_transformers and i < len (output_transformers )
261+ else None
262+ )
263+ processed_result = _process_output_value (name , result , transformer_func )
264+ uploaded_urls .append (processed_result )
265+
266+ if len (uploaded_urls ) == 1 :
267+ return uploaded_urls [0 ]
268+ return tuple (uploaded_urls )
269+
270+ if inspect .iscoroutinefunction (func ):
271+ @functools .wraps (func )
272+ async def async_wrapper (* args , ** kwargs ):
273+ results = await func (* args , ** kwargs )
274+ return _handle_results (results )
275+
276+ return async_wrapper
277+
278+ @functools .wraps (func )
279+ def wrapper (* args , ** kwargs ):
280+ results = func (* args , ** kwargs )
281+ return _handle_results (results )
282+
283+ return wrapper
284+
285+ return decorator
0 commit comments