Skip to content

Commit 4b4624a

Browse files
committed
✨ Multi-modal data sdk
1 parent 7be849a commit 4b4624a

File tree

5 files changed

+1234
-0
lines changed

5 files changed

+1234
-0
lines changed

sdk/nexent/multi_modal/__init__.py

Whitespace-only changes.
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
import functools
2+
import inspect
3+
import logging
4+
from io import BytesIO
5+
from typing import Any, Callable, List, Optional, Tuple
6+
import requests
7+
8+
from .utils import (
9+
is_url,
10+
generate_object_name,
11+
detect_content_type_from_bytes,
12+
guess_extension_from_content_type,
13+
parse_s3_url
14+
)
15+
16+
# Module-level logger used by LoadSaveObjectManager for download/upload diagnostics.
logger = logging.getLogger("multi_modal")
17+
18+
19+
class LoadSaveObjectManager:
    """
    Provide load/save decorators that operate on a specific storage client.

    The manager is instantiated with a storage client and exposes decorator
    factories for `load_object` and `save_object`. A default module-level manager
    is also provided for backwards compatibility with existing helper functions.

    The storage client is duck-typed. The methods used here are:

    * ``get_file_stream(object_name, bucket)`` -> ``(success, stream_or_error)``
    * ``upload_fileobj(file_obj, object_name, bucket)`` -> ``(success, url_or_error)``
    """

    def __init__(self, storage_client: Any):
        """Store the storage client; it is validated lazily on first use."""
        self._storage_client = storage_client

    @staticmethod
    def _callable_name(func: Callable) -> str:
        """
        Return a printable name for ``func``.

        Callables such as ``functools.partial`` objects or instances defining
        ``__call__`` have no ``__name__``; fall back to ``repr`` so that building
        a log or error message never raises ``AttributeError``.
        """
        name = getattr(func, "__name__", None)
        return name if name else repr(func)

    def _get_client(self) -> Any:
        """
        Return the configured storage client.

        Raises:
            ValueError: if the manager was created with ``None`` as its client.
        """
        if self._storage_client is None:
            raise ValueError("Storage client is not initialized.")
        return self._storage_client

    def download_file_from_url(self, url: str, timeout: int = 30) -> Optional[bytes]:
        """
        Download file content from S3 URL or HTTP/HTTPS URL as bytes.

        Args:
            url: ``http://`` / ``https://`` URL fetched with ``requests``, or an
                ``s3://`` URL / path starting with ``/`` resolved through the
                storage client.
            timeout: timeout in seconds passed to ``requests.get`` (HTTP only).

        Returns:
            The downloaded bytes, or ``None`` when ``url`` is empty or when any
            error occurs (the error is logged, never raised).
        """
        if not url:
            return None

        try:
            if url.startswith(('http://', 'https://')):
                response = requests.get(url, timeout=timeout)
                response.raise_for_status()
                return response.content

            if url.startswith(('s3://', '/')):
                client = self._get_client()
                bucket, object_name = parse_s3_url(url)

                if not hasattr(client, 'get_file_stream'):
                    raise ValueError("Storage client does not have get_file_stream method")

                success, stream = client.get_file_stream(object_name, bucket)
                if not success:
                    # On failure the second tuple element carries the error detail.
                    raise ValueError(f"Failed to get file stream from storage: {stream}")

                try:
                    return stream.read()
                except Exception as exc:
                    raise ValueError(f"Failed to read stream content: {exc}") from exc
                finally:
                    # Always release the stream, even when read() fails.
                    if hasattr(stream, 'close'):
                        stream.close()

            raise ValueError(f"Unsupported URL format: {url[:50]}...")

        except Exception as exc:
            # Deliberate best-effort: callers treat None as "download failed".
            logger.error(f"Failed to download file from URL: {exc}")
            return None

    def _upload_bytes_to_minio(
        self,
        bytes_data: bytes,
        object_name: Optional[str] = None,
        bucket: str = "multi-modal",
        content_type: str = "application/octet-stream",
    ) -> str:
        """
        Upload bytes to MinIO and return the resulting file URL.

        Args:
            bytes_data: payload to upload.
            object_name: target object name; when ``None`` a name is generated
                with an extension guessed from ``content_type``.
            bucket: target bucket name.
            content_type: used only to pick the extension of a generated name.

        Returns:
            Whatever the storage client returns as the second tuple element on
            success.

        Raises:
            ValueError: if the client is missing, lacks ``upload_fileobj``, or
                reports a failed upload.
        """
        client = self._get_client()

        if not hasattr(client, 'upload_fileobj'):
            raise ValueError("Storage client must have upload_fileobj method")

        if object_name is None:
            file_ext = guess_extension_from_content_type(content_type)
            object_name = generate_object_name(file_ext)

        success, result = client.upload_fileobj(BytesIO(bytes_data), object_name, bucket)

        if not success:
            raise ValueError(f"Failed to upload file to MinIO: {result}")

        return result

    def load_object(
        self,
        input_names: List[str],
        input_data_transformer: Optional[List[Callable[[bytes], Any]]] = None,
    ):
        """
        Decorator factory that downloads inputs before invoking the wrapped callable.

        Args:
            input_names: names of the wrapped callable's parameters whose values
                are URL strings (or lists/tuples of URL strings, possibly nested)
                to be replaced by downloaded content.
            input_data_transformer: optional list matched positionally with
                ``input_names``; each entry converts the downloaded bytes before
                they are passed on. Missing entries mean "pass raw bytes".

        ``None`` values (and ``None`` items inside lists/tuples) pass through
        untouched. Both plain and coroutine functions are supported; a coroutine
        function stays a coroutine function after decoration.

        Raises (at call time):
            ValueError: if a value is not a URL string or its download fails.
        """

        def decorator(func: Callable):
            # The signature never changes, so resolve it once per decoration
            # instead of once per call.
            sig = inspect.signature(func)

            def _transform_single_value(param_name: str, value: Any,
                                        transformer: Optional[Callable[[bytes], Any]]) -> Any:
                if not (isinstance(value, str) and is_url(value)):
                    raise ValueError(
                        f"Parameter '{param_name}' is not a URL string. "
                        f"load_object decorator expects S3 or HTTP/HTTPS URLs. "
                        f"Got: {type(value).__name__}"
                    )

                bytes_data = self.download_file_from_url(value)
                if bytes_data is None:
                    raise ValueError(f"Failed to download file from URL: {value}")

                if transformer:
                    transformed_data = transformer(bytes_data)
                    logger.info(
                        f"Downloaded {param_name} from URL and transformed "
                        f"using {self._callable_name(transformer)}"
                    )
                    return transformed_data

                logger.info(f"Downloaded {param_name} from URL as bytes (binary stream)")
                return bytes_data

            def _process_value(param_name: str, value: Any,
                               transformer: Optional[Callable[[bytes], Any]]) -> Any:
                if value is None:
                    return None

                if isinstance(value, (list, tuple)):
                    # Recurse so nested containers keep their shape and type.
                    return type(value)(
                        [_process_value(param_name, item, transformer) for item in value]
                    )

                return _transform_single_value(param_name, value, transformer)

            def _bind_and_load(args: Tuple[Any, ...], kwargs: dict) -> inspect.BoundArguments:
                """Bind the call arguments and replace URL inputs by their content."""
                bound_args = sig.bind(*args, **kwargs)
                bound_args.apply_defaults()

                for i, param_name in enumerate(input_names):
                    if param_name not in bound_args.arguments:
                        continue

                    original_data = bound_args.arguments[param_name]
                    if original_data is None:
                        continue

                    transformer_func = (
                        input_data_transformer[i]
                        if input_data_transformer and i < len(input_data_transformer)
                        else None
                    )
                    bound_args.arguments[param_name] = _process_value(
                        param_name, original_data, transformer_func
                    )

                return bound_args

            if inspect.iscoroutinefunction(func):
                # Mirror save_object: keep coroutine functions recognisable as such.
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    bound_args = _bind_and_load(args, kwargs)
                    return await func(*bound_args.args, **bound_args.kwargs)

                return async_wrapper

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                bound_args = _bind_and_load(args, kwargs)
                return func(*bound_args.args, **bound_args.kwargs)

            return wrapper

        return decorator

    def save_object(
        self,
        output_names: List[str],
        output_transformers: Optional[List[Callable[[Any], bytes]]] = None,
        bucket: str = "multi-modal",
    ):
        """
        Decorator factory that uploads outputs to storage after function execution.

        Args:
            output_names: one label per returned value; a non-tuple return counts
                as a single value.
            output_transformers: optional list matched positionally with
                ``output_names``; each entry must convert its value to ``bytes``.
                Without a transformer the value itself must be ``bytes``.
            bucket: bucket the outputs are uploaded to.

        The decorated callable returns the uploaded URL(s) instead of the raw
        values: a single item when there is one output, otherwise a tuple.
        ``None`` values pass through, and lists/tuples are processed item by item.

        Raises (at call time):
            ValueError: if the number of returned values differs from
                ``len(output_names)``, a value is not bytes, or the upload fails.
        """

        def decorator(func: Callable) -> Callable:
            def _upload_single_output(
                name: str,
                value: Any,
                transformer: Optional[Callable[[Any], bytes]]
            ) -> str:
                if transformer:
                    transformer_name = self._callable_name(transformer)
                    bytes_data = transformer(value)
                    if not isinstance(bytes_data, bytes):
                        raise ValueError(
                            f"Transformer {transformer_name} for {name} must return bytes, "
                            f"got {type(bytes_data).__name__}"
                        )
                    logger.info(f"Transformed {name} using {transformer_name} to bytes")
                else:
                    if not isinstance(value, bytes):
                        raise ValueError(
                            f"Return value for {name} must be bytes when no transformer is provided, "
                            f"got {type(value).__name__}"
                        )
                    bytes_data = value
                    logger.info(f"Using {name} as bytes directly")

                content_type = detect_content_type_from_bytes(bytes_data)
                logger.info(f"Detected content type for {name}: {content_type}")

                file_url = self._upload_bytes_to_minio(
                    bytes_data,
                    object_name=None,
                    content_type=content_type,
                    bucket=bucket,
                )
                logger.info(f"Uploaded {name} to MinIO: {file_url}")
                # NOTE(review): presumably the storage client returns a path that
                # starts with "/", so this yields "s3://bucket/object" - confirm.
                return "s3:/" + file_url

            def _process_output_value(
                name: str,
                value: Any,
                transformer: Optional[Callable[[Any], bytes]]
            ) -> Any:
                if value is None:
                    return None

                if isinstance(value, (list, tuple)):
                    # Recurse so nested containers keep their shape and type.
                    return type(value)(
                        [_process_output_value(name, item, transformer) for item in value]
                    )

                return _upload_single_output(name, value, transformer)

            def _handle_results(results: Any):
                results_tuple = results if isinstance(results, tuple) else (results,)

                if len(results_tuple) != len(output_names):
                    raise ValueError(
                        f"Function returned {len(results_tuple)} values, "
                        f"but expected {len(output_names)} outputs"
                    )

                uploaded_urls = []
                for i, (result, name) in enumerate(zip(results_tuple, output_names)):
                    transformer_func = (
                        output_transformers[i]
                        if output_transformers and i < len(output_transformers)
                        else None
                    )
                    uploaded_urls.append(_process_output_value(name, result, transformer_func))

                if len(uploaded_urls) == 1:
                    return uploaded_urls[0]
                return tuple(uploaded_urls)

            if inspect.iscoroutinefunction(func):
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    results = await func(*args, **kwargs)
                    return _handle_results(results)

                return async_wrapper

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                results = func(*args, **kwargs)
                return _handle_results(results)

            return wrapper

        return decorator

0 commit comments

Comments
 (0)