Skip to content

Commit fe188e3

Browse files
authored
[#34236] Add Vertex AI Multi-Modal embedding handler (#35677)
* Prototype Vertex MultiModal embedding handler * remove unused types * change temp file and artifact paths to use dedicated directories * formatting * quick unit tests for the base multimodal embedding handler * Migrate to input adapter, add testing for video * linting * isort * made segment configuration per-video instance * fix corrected video input type * speed up video test by passing a GCS URI instead of loading the video * formatting * move to wrapped inputs * clarify types in dict_input_fn * linting * fix chunk construction * update main input to use wrappers
1 parent 32aff26 commit fe188e3

File tree

4 files changed

+501
-3
lines changed

4 files changed

+501
-3
lines changed

sdks/python/apache_beam/ml/transforms/base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,42 @@ def get_metrics_namespace(self) -> str:
810810
return (
811811
self._underlying.get_metrics_namespace() or
812812
'BeamML_ImageEmbeddingHandler')
813+
814+
815+
class _MultiModalEmbeddingHandler(_EmbeddingHandler):
  """
  A ModelHandler intended to work on
  list[dict[str, TypedDict(Image, Video, str)]] inputs.

  The inputs to the model handler are expected to be a list of dicts.

  For example, if the original model is used with RunInference to take a
  PCollection[E] to a PCollection[P], this ModelHandler would take a
  PCollection[dict[str, E]] to a PCollection[dict[str, P]].

  _MultiModalEmbeddingHandler will accept an EmbeddingsManager instance, which
  contains the details of the model to be loaded and the inference_fn to be
  used. The purpose of _MultiModalEmbeddingHandler is to generate embeddings
  for image, video, and text inputs using the EmbeddingsManager instance.

  If a column contains primitive types (int, str, float, bool), a TypeError
  will be raised, since multi-modal inputs are expected to be wrapper
  dataclasses rather than bare primitives.

  This is an internal class and offers no backwards compatibility guarantees.

  Args:
    embeddings_manager: An EmbeddingsManager instance.
  """
  def _validate_column_data(self, batch):
    """Reject batches whose first element is a bare primitive.

    Don't want to require framework-specific imports here, so just catch
    columns of primitives for now.
    """
    if isinstance(batch[0], (int, str, float, bool)):
      raise TypeError(
          'Embeddings can only be generated on '
          'dict[str, dataclass] types. '
          f'Got dict[str, {type(batch[0])}] instead.')

  def get_metrics_namespace(self) -> str:
    """Return the underlying handler's namespace, or the multi-modal default."""
    return (
        self._underlying.get_metrics_namespace() or
        'BeamML_MultiModalEmbeddingHandler')

sdks/python/apache_beam/ml/transforms/base_test.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import time
2424
import unittest
2525
from collections.abc import Sequence
26+
from dataclasses import dataclass
2627
from typing import Any
2728
from typing import Optional
2829

@@ -629,6 +630,122 @@ def test_handler_with_dict_inputs(self):
629630
)
630631

631632

633+
@dataclass
class FakeMultiModalInput:
  """Test stand-in for a multi-modal embedding input (image/video/text).

  All fields are optional so tests can populate any subset of modalities.
  """
  # Decoded image payload; the tests pass PIL images, used as opaque values.
  image: Optional[PIL_Image] = None
  # Raw video payload; tests pass bytes, but any object is accepted.
  video: Optional[Any] = None
  # Plain-text payload.
  text: Optional[str] = None
638+
639+
640+
class FakeMultiModalModel:
  """Identity model: validates element types, then echoes the batch back."""
  def __call__(self,
               example: list[FakeMultiModalInput]) -> list[FakeMultiModalInput]:
    """Return `example` unchanged; raise TypeError on foreign elements."""
    # Iterate elements directly rather than indexing via range(len(...)).
    for item in example:
      if not isinstance(item, FakeMultiModalInput):
        raise TypeError('Input must be a MultiModalInput')
    return example
647+
648+
649+
class FakeMultiModalModelHandler(ModelHandler):
  """RunInference ModelHandler test double backed by FakeMultiModalModel."""
  def run_inference(
      self,
      batch: Sequence[FakeMultiModalInput],
      model: Any,
      inference_args: Optional[dict[str, Any]] = None):
    """Apply the (fake) model to the whole batch; inference_args is unused."""
    return model(batch)

  def load_model(self):
    """Return the fake model instance passed to run_inference."""
    return FakeMultiModalModel()
659+
660+
661+
class FakeMultiModalEmbeddingsManager(base.EmbeddingsManager):
  """EmbeddingsManager test double wiring in the multi-modal fakes."""
  def __init__(self, columns, **kwargs):
    super().__init__(columns=columns, **kwargs)

  def get_model_handler(self) -> ModelHandler:
    # Patch __repr__ on the handler class this manager actually returns.
    # (The original patched FakeModelHandler — a copy-paste slip from the
    # text-embedding test fixtures.)
    FakeMultiModalModelHandler.__repr__ = lambda x: 'FakeMultiModalEmbeddingsManager'  # type: ignore[method-assign]
    return FakeMultiModalModelHandler()

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    """Wrap the fake handler in RunInference via _MultiModalEmbeddingHandler."""
    return (RunInference(model_handler=base._MultiModalEmbeddingHandler(self)))

  def __repr__(self):
    return 'FakeMultiModalEmbeddingsManager'
674+
675+
676+
class TestMultiModalEmbeddingHandler(unittest.TestCase):
  """Tests for base._MultiModalEmbeddingHandler validation and inference."""
  def setUp(self) -> None:
    self.embedding_config = FakeMultiModalEmbeddingsManager(columns=['x'])
    self.artifact_location = tempfile.mkdtemp()

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  @unittest.skipIf(PIL is None, 'PIL module is not installed.')
  def test_handler_with_non_dict_datatype(self):
    """Non-dict batch elements (tuples) must be rejected with TypeError."""
    handler = base._MultiModalEmbeddingHandler(
        embeddings_manager=self.embedding_config)
    data = [
        ('x', 'hi there'),
        ('x', 'not an image'),
        ('x', 'image_path.jpg'),
    ]
    with self.assertRaises(TypeError):
      handler.run_inference(data, None, None)

  @unittest.skipIf(PIL is None, 'PIL module is not installed.')
  def test_handler_with_incorrect_datatype(self):
    """dict values of primitive type (str) must be rejected with TypeError."""
    handler = base._MultiModalEmbeddingHandler(
        embeddings_manager=self.embedding_config)
    data = [
        {
            'x': 'hi there'
        },
        {
            'x': 'not an image'
        },
        {
            'x': 'image_path.jpg'
        },
    ]
    with self.assertRaises(TypeError):
      handler.run_inference(data, None, None)

  @unittest.skipIf(PIL is None, 'PIL module is not installed.')
  def test_handler_with_dict_inputs(self):
    """Well-formed multi-modal inputs pass through the identity model."""
    input_one = FakeMultiModalInput(
        image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image one")
    input_two = FakeMultiModalInput(
        image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image two")
    input_three = FakeMultiModalInput(
        image=PIL.Image.new(mode='RGB', size=(1, 1)),
        video=bytes.fromhex('2Ef0 F1f2 '),
        text="test image three with video")
    data = [
        {
            'x': input_one
        },
        {
            'x': input_two
        },
        {
            'x': input_three
        },
    ]
    # Shallow-copy each element dict (the original spelled this as a
    # redundant identity dict comprehension).
    expected_data = [dict(d) for d in data]
    with beam.Pipeline() as p:
      result = (
          p
          | beam.Create(data)
          | base.MLTransform(
              write_artifact_location=self.artifact_location).with_transform(
                  self.embedding_config))
      assert_that(
          result,
          equal_to(expected_data),
      )
747+
748+
632749
class TestUtilFunctions(unittest.TestCase):
633750
def test_dict_input_fn_normal(self):
634751
input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]

0 commit comments

Comments
 (0)