diff --git a/CHANGES.md b/CHANGES.md index e8f6f79d4dbf..bf2fe433d7b2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -74,6 +74,7 @@ ## New Features / Improvements * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* OpenAI embedding model added to Beam RAG module ([#36083](https://github.com/apache/beam/issues/36083)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 0bb6ccd6108e..bdbee9e51fd5 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -31,12 +31,13 @@ from typing import Any from typing import Optional +from openai import AsyncOpenAI +from openai import OpenAI + from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.utils import subprocess_server -from openai import AsyncOpenAI -from openai import OpenAI try: # VLLM logging config breaks beam logging. diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py index f0b9316dcee8..8c91adafe543 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -24,6 +24,7 @@ import apache_beam as beam from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings +from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Content from apache_beam.ml.rag.types import Embedding @@ -40,19 +41,6 @@ SENTENCE_TRANSFORMERS_AVAILABLE = False -def chunk_approximately_equals(expected, actual): - """Compare embeddings allowing for numerical differences.""" - if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): - return False - - return ( - expected.id == actual.id and expected.metadata == actual.metadata and - expected.content == actual.content and - len(expected.embedding.dense_embedding) == len( - actual.embedding.dense_embedding) and - all(isinstance(x, float) for x in actual.embedding.dense_embedding)) - - @pytest.mark.uses_transformers @unittest.skipIf( not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available") diff --git a/sdks/python/apache_beam/ml/rag/embeddings/open_ai.py b/sdks/python/apache_beam/ml/rag/embeddings/open_ai.py new file mode 100644 index 000000000000..1dbf168a3a02 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/open_ai.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAG-specific embedding implementations using OpenAI models.""" + +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.ml.transforms.embeddings.open_ai import _OpenAITextEmbeddingHandler + +__all__ = ['OpenAITextEmbeddings'] + + +class OpenAITextEmbeddings(EmbeddingsManager): + def __init__( + self, + model_name: str, + *, + api_key: Optional[str] = None, + organization: Optional[str] = None, + dimensions: Optional[int] = None, + user: Optional[str] = None, + max_batch_size: Optional[int] = None, + **kwargs): + """Utilizes OpenAI text embeddings for semantic search and RAG pipelines. + + Args: + model_name: Name of the OpenAI embedding model + api_key: OpenAI API key + organization: OpenAI organization ID + dimensions: Specific embedding dimensions to use (if supported) + user: End-user identifier for tracking and rate limit calculations + max_batch_size: Maximum batch size for requests to OpenAI API + **kwargs: Additional arguments passed to EmbeddingsManager including + ModelHandler inference_args. + """ + super().__init__(type_adapter=create_rag_adapter(), **kwargs) + self.model_name = model_name + self.api_key = api_key + self.organization = organization + self.dimensions = dimensions + self.user = user + self.max_batch_size = max_batch_size + + def get_model_handler(self): + """Returns model handler configured with RAG adapter.""" + return _OpenAITextEmbeddingHandler( + model_name=self.model_name, + api_key=self.api_key, + organization=self.organization, + dimensions=self.dimensions, + user=self.user, + max_batch_size=self.max_batch_size, + ) + + def get_ptransform_for_processing( + self, **kwargs + ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]: + """Returns PTransform that uses the RAG adapter.""" + return RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args).with_output_types(Chunk) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py b/sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py new file mode 100644 index 000000000000..f263d45aae60 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings +from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +class OpenAITextEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.artifact_location = tempfile.mkdtemp(prefix='openai_') + self.test_chunks = [ + Chunk( + content=Content(text="This is a test sentence."), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }), + Chunk( + content=Content(text="Another example."), + id="2", + metadata={ + "source": "test.txt", "language": "en" + }) + ] + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_embedding_pipeline(self): + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.0] * 1536), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="This is a test sentence.")), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.0] * 1536), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="Another example.")) + ] + + embedder = OpenAITextEmbeddings( + model_name="text-embedding-3-small", + dimensions=1536, + api_key=os.environ.get("OPENAI_API_KEY")) + + with TestPipeline() as p: + embeddings = ( + p + | beam.Create(self.test_chunks) + | MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedder)) + + assert_that( + embeddings, equal_to(expected, equals_fn=chunk_approximately_equals)) + + def test_embedding_pipeline_with_dimensions(self): + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.0] * 512), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="This is a test sentence.")), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.0] * 512), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="Another example.")) + ] + + embedder = OpenAITextEmbeddings( + model_name="text-embedding-3-small", + dimensions=512, + api_key=os.environ.get("OPENAI_API_KEY")) + + with TestPipeline() as p: + embeddings = ( + p + | beam.Create(self.test_chunks) + | MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedder)) + + assert_that( + embeddings, equal_to(expected, equals_fn=chunk_approximately_equals)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/embeddings/test_utils.py b/sdks/python/apache_beam/ml/rag/embeddings/test_utils.py new file mode 100644 index 000000000000..7443a335b9be --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/test_utils.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for RAG embeddings.""" + +from apache_beam.ml.rag.types import Chunk + + +def chunk_approximately_equals(expected, actual): + """Compare embeddings allowing for numerical differences.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + + return ( + expected.id == actual.id and expected.metadata == actual.metadata and + expected.content == actual.content and + len(expected.embedding.dense_embedding) == len( + actual.embedding.dense_embedding) and + all(isinstance(x, float) for x in actual.embedding.dense_embedding)) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py index 320a562d5009..366975a5a76a 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py @@ -21,6 +21,7 @@ import unittest import apache_beam as beam +from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Content from apache_beam.ml.rag.types import Embedding @@ -38,19 +39,6 @@ VERTEX_AI_AVAILABLE = False -def chunk_approximately_equals(expected, actual): - """Compare embeddings allowing for numerical differences.""" - if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): - return False - - return ( - expected.id == actual.id and expected.metadata == actual.metadata and - expected.content == actual.content and - len(expected.embedding.dense_embedding) == len( - actual.embedding.dense_embedding) and - all(isinstance(x, float) for x in actual.embedding.dense_embedding)) - - @unittest.skipIf( not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available") class VertexAITextEmbeddingsTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/open_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/open_ai.py index a162c333b199..121fa9839ef7 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/open_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/open_ai.py @@ -21,16 +21,17 @@ from typing import TypeVar from typing import Union -import apache_beam as beam import openai +from openai import APIError +from openai import RateLimitError + +import apache_beam as beam from apache_beam.ml.inference.base import RemoteModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _TextEmbeddingHandler from apache_beam.pvalue import PCollection from apache_beam.pvalue import Row -from openai import APIError -from openai import RateLimitError __all__ = ["OpenAITextEmbeddings"] @@ -103,12 +104,12 @@ def request( ) -> Iterable: """Makes a request to OpenAI embedding API and returns embeddings.""" # Prepare arguments for the API call - kwargs = { + kwargs: dict[str, Any] = { "model": self.model_name, "input": batch, } if self.dimensions: - kwargs["dimensions"] = [str(self.dimensions)] + kwargs["dimensions"] = self.dimensions if self.user: kwargs["user"] = self.user @@ -139,7 +140,7 @@ def __init__( """ Embedding Config for OpenAI Text Embedding models. Text Embeddings are generated for a batch of text using the OpenAI API. - + Args: model_name: Name of the OpenAI embedding model columns: The columns where the embeddings will be stored in the output diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/open_ai_it_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/open_ai_it_test.py index 118c656c33c3..deca2671ecca 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/open_ai_it_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/open_ai_it_test.py @@ -23,15 +23,12 @@ from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms import base from apache_beam.ml.transforms.base import MLTransform - -try: - from sdks.python.apache_beam.ml.transforms.embeddings.open_ai import OpenAITextEmbeddings -except ImportError: - OpenAITextEmbeddings = None +from apache_beam.ml.transforms.embeddings.open_ai import OpenAITextEmbeddings # pylint: disable=ungrouped-imports try: import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 except ImportError: tft = None @@ -40,8 +37,6 @@ model_name: str = "text-embedding-3-small" -@unittest.skipIf( - OpenAITextEmbeddings is None, 'OpenAI Python SDK is not installed.') class OpenAIEmbeddingsTest(unittest.TestCase): def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp(prefix='_openai_test') @@ -76,6 +71,7 @@ def test_embeddings_with_scale_to_0_1(self): columns=[test_query_column], api_key=self.api_key, ) + scale_config = ScaleTo01(columns=['embedding']) with beam.Pipeline() as pipeline: transformed_pcoll = ( pipeline @@ -84,10 +80,12 @@ def test_embeddings_with_scale_to_0_1(self): }]) | "MLTransform" >> MLTransform( write_artifact_location=self.artifact_location).with_transform( - embedding_config)) + embedding_config).with_transform(scale_config)) def assert_element(element): - assert max(element.feature_1) == 1 + embedding_values = element.embedding + assert 0 <= max(embedding_values) <= 1 + assert 0 <= min(embedding_values) <= 1 _ = (transformed_pcoll | beam.Map(assert_element)) @@ -186,7 +184,7 @@ def test_with_int_data_types(self): write_artifact_location=self.artifact_location).with_transform( embedding_config)) - def test_with_artifact_location(self): # pylint: disable=line-too-long + def test_with_artifact_location(self): """Local artifact location test""" secondary_artifact_location = tempfile.mkdtemp( prefix='_openai_secondary_test') @@ -231,7 +229,7 @@ def assert_element(element): # Clean up the temporary directory shutil.rmtree(secondary_artifact_location) - def test_mltransform_to_ptransform_with_openai(self): # pylint: disable=line-too-long + def test_mltransform_to_ptransform_with_openai(self): transforms = [ OpenAITextEmbeddings( columns=['x'], diff --git a/sdks/python/setup.py b/sdks/python/setup.py index e7ffc0c9780c..814e5f9fdcf4 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -161,6 +161,7 @@ def cythonize(*args, **kwargs): ] milvus_dependency = ['pymilvus>=2.5.10,<3.0.0'] +openai_dependency = ['openai>=1.0.0,<2.0.0'] def find_by_ext(root_dir, ext): @@ -418,9 +419,8 @@ def get_portability_package_data(): 'docutils>=0.18.1', 'markdown', 'pandas<2.3.0', - 'openai', 'virtualenv-clone>=0.5,<1.0', - ], + ] + openai_dependency, 'test': [ 'cloud-sql-python-connector[pg8000]>=1.0.0,<2.0.0', 'docstring-parser>=0.15,<1.0', @@ -449,7 +449,7 @@ def get_portability_package_data(): 'pg8000>=1.31.1', "PyMySQL>=1.1.0", 'oracledb>=3.1.1' - ] + milvus_dependency, + ] + milvus_dependency + openai_dependency, 'gcp': [ 'cachetools>=3.1.0,<7', 'google-api-core>=2.0.0,<3', @@ -596,7 +596,8 @@ def get_portability_package_data(): ], 'xgboost': ['xgboost>=1.6.0,<2.1.3', 'datatable==1.0.0'], 'tensorflow-hub': ['tensorflow-hub>=0.14.0,<0.16.0'], - 'milvus': milvus_dependency + 'milvus': milvus_dependency, + 'openai': openai_dependency }, zip_safe=False, # PyPI package information.