
Commit d493275

[yaml] Fix ml embeddings issue (#35715)
* add a few more libraries
* add more libraries and test
* update types
* fix Gemini review comments
* protect input fn
* update types
* fix lint issues
* more lint
* fix ml transform imports
* update warning and make list_submodules private
1 parent 72ecb47 commit d493275

File tree: 4 files changed (81 additions, 8 deletions)

sdks/python/apache_beam/ml/transforms/base.py

Lines changed: 21 additions & 7 deletions
@@ -183,9 +183,13 @@ def append_transform(self, transform: BaseOperation):
     """
 
 
-def _dict_input_fn(columns: Sequence[str],
-                   batch: Sequence[Dict[str, Any]]) -> List[str]:
+def _dict_input_fn(
+    columns: Sequence[str], batch: Sequence[Union[Dict[str, Any],
+                                                  beam.Row]]) -> List[str]:
   """Extract text from specified columns in batch."""
+  if batch and hasattr(batch[0], '_asdict'):
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
   if not batch or not isinstance(batch[0], dict):
     raise TypeError(
         'Expected data to be dicts, got '
@@ -196,7 +200,7 @@ def _dict_input_fn(columns: Sequence[str],
   expected_columns = set(columns)
   # Process one batch item at a time
   for item in batch:
-    item_keys = item.keys()
+    item_keys = item.keys() if isinstance(item, dict) else set()
     if set(item_keys) != expected_keys:
       extra_keys = item_keys - expected_keys
       missing_keys = expected_keys - item_keys
@@ -212,21 +216,31 @@ def _dict_input_fn(columns: Sequence[str],
 
     # Get all columns for this item
     for col in columns:
-      result.append(item[col])
+      if isinstance(item, dict):
+        result.append(item[col])
   return result
 
 
 def _dict_output_fn(
     columns: Sequence[str],
-    batch: Sequence[Dict[str, Any]],
-    embeddings: Sequence[Any]) -> List[Dict[str, Any]]:
+    batch: Sequence[Union[Dict[str, Any], beam.Row]],
+    embeddings: Sequence[Any]) -> list[Union[dict[str, Any], beam.Row]]:
   """Map embeddings back to columns in batch."""
+  is_beam_row = False
+  if batch and hasattr(batch[0], '_asdict'):
+    is_beam_row = True
+    batch = [row._asdict() for row in batch if hasattr(row, '_asdict')]
+
   result = []
   for batch_idx, item in enumerate(batch):
     for col_idx, col in enumerate(columns):
       embedding_idx = batch_idx * len(columns) + col_idx
-      item[col] = embeddings[embedding_idx]
+      if isinstance(item, dict):
+        item[col] = embeddings[embedding_idx]
     result.append(item)
+
+  if is_beam_row:
+    result = [beam.Row(**item) for item in result if isinstance(item, dict)]
   return result
 
 
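The fix is duck-typed: any batch element that exposes _asdict() (as schema'd rows do) is converted to a plain dict before column extraction, and _dict_output_fn rebuilds beam.Row objects once the embeddings are attached. Below is a minimal standalone sketch (not the Beam source) of the input-side conversion, using a NamedTuple as a stand-in for a schema'd row; the names are illustrative only.

# Illustrative sketch only: mirrors the duck typing the patch relies on,
# where any element exposing _asdict() is treated as a row-like object.
from typing import Any, Dict, List, NamedTuple, Sequence, Union


class LogRow(NamedTuple):  # stand-in for a schema'd beam.Row element
  id: int
  log_message: str


def extract_text(columns: Sequence[str],
                 batch: Sequence[Union[Dict[str, Any], LogRow]]) -> List[str]:
  # Convert row-like elements to plain dicts, as the patched _dict_input_fn
  # does before validating keys and pulling out the configured columns.
  if batch and hasattr(batch[0], '_asdict'):
    batch = [row._asdict() for row in batch]
  return [item[col] for item in batch for col in columns]


batch = [LogRow(1, 'Error in module A'), LogRow(2, 'Warning in module B')]
print(extract_text(['log_message'], batch))
# ['Error in module A', 'Warning in module B']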

sdks/python/apache_beam/yaml/yaml_ml.py

Lines changed: 27 additions & 1 deletion
@@ -16,7 +16,10 @@
 #
 
 """This module defines yaml wrappings for some ML transforms."""
+import logging
+import pkgutil
 from collections.abc import Callable
+from importlib import import_module
 from typing import Any
 from typing import Optional
 
@@ -29,14 +32,37 @@
 from apache_beam.yaml import options
 from apache_beam.yaml.yaml_utils import SafeLineLoader
 
+
+def _list_submodules(package):
+  """
+  Lists all submodules within a given package.
+  """
+  submodules = []
+  for _, module_name, _ in pkgutil.walk_packages(
+      package.__path__, package.__name__ + '.'):
+    if 'test' in module_name:
+      continue
+    submodules.append(module_name)
+  return submodules
+
+
 try:
   from apache_beam.ml.transforms import tft
   from apache_beam.ml.transforms.base import MLTransform
   # TODO(robertwb): Is this all of them?
-  _transform_constructors = tft.__dict__
+  _transform_constructors = {}
 except ImportError:
   tft = None  # type: ignore
 
+# Load all available ML Transform modules
+for module_name in _list_submodules(beam.ml.transforms):
+  try:
+    module = import_module(module_name)
+    _transform_constructors |= module.__dict__
+  except ImportError as e:
+    logging.warning('Could not load ML transform module %s: %s. Please ' \
+        'install the necessary module dependencies', module_name, e)
+
 
 class ModelHandlerProvider:
   handler_types: dict[str, Callable[..., "ModelHandlerProvider"]] = {}
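The discovery loop replaces the hard-coded tft.__dict__ lookup: every non-test submodule under apache_beam.ml.transforms is imported and its names are merged into _transform_constructors, with a logged warning rather than a crash when optional dependencies are missing. A quick way to see what that loop would pick up on a given installation is sketched below; it assumes apache_beam is importable, and the output depends on which extras are installed.

# Sketch: enumerate the ML transform submodules the new loop will try to load.
# With no onerror callback, pkgutil.walk_packages ignores subpackages whose
# import raises ImportError.
import pkgutil

import apache_beam.ml.transforms as ml_transforms

for _, module_name, _ in pkgutil.walk_packages(ml_transforms.__path__,
                                               ml_transforms.__name__ + '.'):
  if 'test' in module_name:
    continue
  print(module_name)  # e.g. apache_beam.ml.transforms.tft, ...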

sdks/python/apache_beam/yaml/yaml_ml_test.py

Lines changed: 32 additions & 0 deletions
@@ -86,6 +86,38 @@ def test_ml_transform(self):
         equal_to([5]),
         label='CheckVocab')
 
+  def test_sentence_transformer_embedding(self):
+    SENTENCE_EMBEDDING_DIMENSION = 384
+    DATA = [{
+        'id': 1, 'log_message': "Error in module A"
+    }, {
+        'id': 2, 'log_message': "Warning in module B"
+    }, {
+        'id': 3, 'log_message': "Info in module C"
+    }]
+    ml_opts = beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['ML'])
+    with tempfile.TemporaryDirectory() as tempdir:
+      with beam.Pipeline(options=ml_opts) as p:
+        elements = p | beam.Create(DATA)
+        result = elements | YamlTransform(
+            f'''
+            type: MLTransform
+            config:
+              write_artifact_location: {tempdir}
+              transforms:
+                - type: SentenceTransformerEmbeddings
+                  config:
+                    model_name: all-MiniLM-L6-v2
+                    columns: [log_message]
+            ''')
+
+        # Perform a basic check to ensure that embeddings are generated
+        # and that the dimension of those embeddings is correct.
+        actual_output = result | beam.Map(lambda x: len(x['log_message']))
+        assert_that(
+            actual_output, equal_to([SENTENCE_EMBEDDING_DIMENSION] * len(DATA)))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
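The new test drives the transform with plain dicts; the base.py changes are what also let it accept schema'd rows. Below is a hedged sketch (not part of the commit) of the same YAML transform applied to beam.Row input; it assumes the sentence-transformers dependency is installed and downloads the all-MiniLM-L6-v2 model on first use.

# Sketch only: exercise the YAML MLTransform with beam.Row elements, the case
# the _dict_input_fn / _dict_output_fn changes above are meant to handle.
import tempfile

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.yaml.yaml_transform import YamlTransform

opts = PipelineOptions(
    pickle_library='cloudpickle', yaml_experimental_features=['ML'])

with tempfile.TemporaryDirectory() as tempdir:
  with beam.Pipeline(options=opts) as p:
    _ = (
        p
        | beam.Create([
            beam.Row(id=1, log_message='Error in module A'),
            beam.Row(id=2, log_message='Warning in module B'),
        ])
        | YamlTransform(
            f'''
            type: MLTransform
            config:
              write_artifact_location: {tempdir}
              transforms:
                - type: SentenceTransformerEmbeddings
                  config:
                    model_name: all-MiniLM-L6-v2
                    columns: [log_message]
            ''')
        # Each element should now carry an embedding vector in place of the
        # original log_message text.
        | beam.Map(print))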

sdks/python/setup.py

Lines changed: 1 addition & 0 deletions
@@ -518,6 +518,7 @@ def get_portability_package_data():
         'pyod',
         'tensorflow',
         'tensorflow-hub',
+        'tensorflow-text',
         'tensorflow-transform',
         'tf2onnx',
         'torch',
