1616
1717from __future__ import annotations
1818
19- from typing import Any , Dict , Iterable , List
19+ from collections .abc import Iterable
20+ from typing import Any
2021
2122
2223class ExportEmbeddingsDataRunner :
2324 """Callable runner that merges raw data and embeddings from multiple inputs."""
2425
25- def __call__ (self , inputs : Any , ** parameters : Any ) -> Dict [str , Any ]:
26+ def __call__ (self , inputs : Any , ** parameters : Any ) -> dict [str , Any ]:
2627 data_inputs = self ._extract_data_inputs (inputs , parameters )
2728 payloads = list (self ._iter_payloads (data_inputs ))
2829
@@ -35,7 +36,7 @@ def __call__(self, inputs: Any, **parameters: Any) -> Dict[str, Any]:
3536 }
3637
3738 @staticmethod
38- def _extract_data_inputs (inputs : Any , parameters : Dict [str , Any ]) -> List [Any ]:
39+ def _extract_data_inputs (inputs : Any , parameters : dict [str , Any ]) -> list [Any ]:
3940 """Prefer explicit data_inputs in parameters; fall back to request inputs."""
4041 if "data_inputs" in parameters and parameters ["data_inputs" ]:
4142 candidates = parameters ["data_inputs" ]
@@ -51,7 +52,7 @@ def _extract_data_inputs(inputs: Any, parameters: Dict[str, Any]) -> List[Any]:
5152 @staticmethod
5253 def _iter_payloads (data_inputs : Iterable [Any ]) -> Iterable [dict ]:
5354 """Yield normalized payloads from input data.
54-
55+
5556 Handles various input structures including:
5657 - PrecomputedEmbeddings: {"type": "PrecomputedEmbeddings", "vectors": [...], "texts": [...]}
5758 - Data objects: {"data": {...}}
@@ -77,10 +78,10 @@ def _iter_payloads(data_inputs: Iterable[Any]) -> Iterable[dict]:
7778 yield {"value" : item }
7879
7980 @staticmethod
80- def _collect_raw_items (payloads : List [dict ]) -> List [Any ]:
81+ def _collect_raw_items (payloads : list [dict ]) -> list [Any ]:
8182 """Collect raw data items (text + metadata) from payloads."""
82- merged : List [Any ] = []
83-
83+ merged : list [Any ] = []
84+
8485 for payload in payloads :
8586 # Handle PrecomputedEmbeddings format
8687 if payload .get ("type" ) == "PrecomputedEmbeddings" :
@@ -92,21 +93,23 @@ def _collect_raw_items(payloads: List[dict]) -> List[Any]:
9293 item ["embeddings" ] = vectors [i ]
9394 merged .append (item )
9495 continue
95-
96+
9697 # Collect from explicit "items" array
9798 if isinstance (payload .get ("items" ), list ):
9899 merged .extend (payload ["items" ])
99100 continue
100-
101+
101102 # Collect single text entry
102103 if payload .get ("text" ):
103- merged .append ({
104- "text" : payload .get ("text" ),
105- "model" : payload .get ("model" ),
106- "embeddings" : payload .get ("embeddings" ),
107- })
104+ merged .append (
105+ {
106+ "text" : payload .get ("text" ),
107+ "model" : payload .get ("model" ),
108+ "embeddings" : payload .get ("embeddings" ),
109+ }
110+ )
108111 continue
109-
112+
110113 # Collect multiple texts
111114 if payload .get ("texts" ) and isinstance (payload ["texts" ], list ):
112115 texts = payload ["texts" ]
@@ -120,27 +123,25 @@ def _collect_raw_items(payloads: List[dict]) -> List[Any]:
120123 continue
121124
122125 # Check nested structure (legacy format)
123- nested_data = (
124- payload .get ("locals" , {})
125- .get ("output" , {})
126- .get ("data" , {})
127- )
126+ nested_data = payload .get ("locals" , {}).get ("output" , {}).get ("data" , {})
128127 if isinstance (nested_data , dict ):
129128 if isinstance (nested_data .get ("items" ), list ):
130129 merged .extend (nested_data ["items" ])
131130 elif nested_data .get ("text" ):
132- merged .append ({
133- "text" : nested_data .get ("text" ),
134- "model" : nested_data .get ("model" ),
135- "embeddings" : nested_data .get ("embeddings" ),
136- })
137-
131+ merged .append (
132+ {
133+ "text" : nested_data .get ("text" ),
134+ "model" : nested_data .get ("model" ),
135+ "embeddings" : nested_data .get ("embeddings" ),
136+ }
137+ )
138+
138139 return merged
139140
140141 @staticmethod
141- def _collect_embeddings (payloads : List [dict ]) -> List [dict ]:
142+ def _collect_embeddings (payloads : list [dict ]) -> list [dict ]:
142143 """Collect embeddings in vector-store compatible format.
143-
144+
144145 Returns list of entries with:
145146 - id: Deterministic hash of text content
146147 - vector: Embedding array (for vector DBs)
@@ -149,35 +150,40 @@ def _collect_embeddings(payloads: List[dict]) -> List[dict]:
149150 - metadata: Same as payload (for other DBs)
150151 """
151152 import hashlib
152-
153- merged : List [dict [str , Any ]] = []
154153
155- def to_entry (text : str , vector : List [float ], extra_metadata : dict | None = None ) -> dict [str , Any ]:
154+ merged : list [dict [str , Any ]] = []
155+
156+ def to_entry (
157+ text : str , vector : list [float ], extra_metadata : dict | None = None
158+ ) -> dict [str , Any ]:
156159 """Create a vector-store compatible entry."""
157160 # Normalize text
158161 if isinstance (text , dict ):
159162 text = text .get ("text" , str (text ))
160163 elif isinstance (text , list ):
161164 # Join list items
162165 text = " | " .join (
163- item .get ("title" , item .get ("text" , str (item )))
164- if isinstance (item , dict ) else str (item )
166+ (
167+ item .get ("title" , item .get ("text" , str (item )))
168+ if isinstance (item , dict )
169+ else str (item )
170+ )
165171 for item in text
166172 )
167173 else :
168174 text = str (text ) if text else ""
169-
175+
170176 # Generate deterministic ID from text content
171177 text_hash = hashlib .md5 (text .encode ()).hexdigest ()[:16 ]
172178 entry_id = f"emb-{ text_hash } "
173-
179+
174180 # Build metadata
175181 metadata = {"text" : text }
176182 if extra_metadata :
177183 for key , value in extra_metadata .items ():
178184 if key not in {"embeddings" , "vector" , "vectors" , "type" , "texts" }:
179185 metadata [key ] = value
180-
186+
181187 return {
182188 "id" : entry_id ,
183189 "vector" : vector , # Standard field name for vector DBs
@@ -192,32 +198,32 @@ def to_entry(text: str, vector: List[float], extra_metadata: dict | None = None)
192198 if payload .get ("type" ) == "PrecomputedEmbeddings" :
193199 vectors = payload .get ("vectors" , [])
194200 texts = payload .get ("texts" , [])
195-
201+
196202 for i , vector in enumerate (vectors ):
197203 if not isinstance (vector , list ):
198204 continue
199205 text = texts [i ] if i < len (texts ) else f"item_{ i } "
200206 entry = to_entry (text , vector , {"model" : payload .get ("model" )})
201207 merged .append (entry )
202208 continue
203-
209+
204210 # Get embeddings - check both "embeddings" and "vectors" keys
205211 embeddings = payload .get ("embeddings" ) or payload .get ("vectors" )
206-
212+
207213 if not embeddings or not isinstance (embeddings , list ):
208214 continue
209-
215+
210216 # Check if it's a single vector (list of floats) or multiple vectors (list of lists)
211- if embeddings and isinstance (embeddings [0 ], ( int , float ) ):
217+ if embeddings and isinstance (embeddings [0 ], int | float ):
212218 # Single vector - pair with single text
213219 text = payload .get ("text" , "" )
214220 entry = to_entry (text , embeddings , payload )
215221 merged .append (entry )
216-
222+
217223 elif embeddings and isinstance (embeddings [0 ], list ):
218224 # Multiple vectors - pair with texts array
219225 texts = payload .get ("texts" , [])
220-
226+
221227 for i , vector in enumerate (embeddings ):
222228 if not isinstance (vector , list ):
223229 continue
@@ -227,15 +233,11 @@ def to_entry(text: str, vector: List[float], extra_metadata: dict | None = None)
227233 merged .append (entry )
228234
229235 # Also check nested structure (legacy format)
230- nested_data = (
231- payload .get ("locals" , {})
232- .get ("output" , {})
233- .get ("data" , {})
234- )
236+ nested_data = payload .get ("locals" , {}).get ("output" , {}).get ("data" , {})
235237 if isinstance (nested_data , dict ):
236238 nested_embeddings = nested_data .get ("embeddings" )
237239 if isinstance (nested_embeddings , list ) and nested_embeddings :
238- if isinstance (nested_embeddings [0 ], ( int , float ) ):
240+ if isinstance (nested_embeddings [0 ], int | float ):
239241 text = nested_data .get ("text" , "" )
240242 entry = to_entry (text , nested_embeddings , nested_data )
241243 merged .append (entry )
@@ -281,4 +283,4 @@ class DFXExportEmbeddingsDataComponent(ExportEmbeddingsDataComponent):
281283 name = "DFXExportEmbeddingsDataComponent"
282284
283285
284- __all__ = ["get_component_runner" ]
286+ __all__ = ["get_component_runner" ]
0 commit comments