Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* OpenAI embedding model added to Beam RAG module (Python) ([#36083](https://github.com/apache/beam/issues/36083)).

## Breaking Changes

Expand Down
5 changes: 3 additions & 2 deletions sdks/python/apache_beam/ml/inference/vllm_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
from typing import Any
from typing import Optional

from openai import AsyncOpenAI
from openai import OpenAI

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.utils import subprocess_server
from openai import AsyncOpenAI
from openai import OpenAI

try:
# VLLM logging config breaks beam logging.
Expand Down
14 changes: 1 addition & 13 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import apache_beam as beam
from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
Expand All @@ -40,19 +41,6 @@
SENTENCE_TRANSFORMERS_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
  """Return whether two Chunks match, ignoring the embedding values.

  The embedding values themselves are never compared. ``id``, ``metadata``
  and ``content`` must be equal, the dense embeddings must have the same
  length, and every value in ``actual``'s dense embedding must be a
  ``float``.

  Args:
    expected: The reference Chunk; its dense embedding only supplies the
      expected length.
    actual: The Chunk produced by the code under test.

  Returns:
    True if the chunks match as described above, otherwise False. Anything
    that is not a Chunk never matches. When either chunk has no dense
    embedding, they match only if both have none.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  if (expected.id != actual.id or expected.metadata != actual.metadata or
      expected.content != actual.content):
    return False

  # A Chunk can be built without an embedding; report a mismatch instead of
  # raising from inside an equals_fn.
  expected_dense = getattr(expected.embedding, 'dense_embedding', None)
  actual_dense = getattr(actual.embedding, 'dense_embedding', None)
  if expected_dense is None or actual_dense is None:
    return expected_dense is None and actual_dense is None

  return (
      len(expected_dense) == len(actual_dense) and
      all(isinstance(x, float) for x in actual_dense))


@pytest.mark.uses_transformers
@unittest.skipIf(
not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available")
Expand Down
80 changes: 80 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/open_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RAG-specific embedding implementations using OpenAI models."""

from typing import Optional

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.open_ai import _OpenAITextEmbeddingHandler

__all__ = ['OpenAITextEmbeddings']


class OpenAITextEmbeddings(EmbeddingsManager):
  """Embeddings manager that produces embedded ``Chunk`` elements via OpenAI.

  The manager is initialised with the RAG type adapter and builds an
  ``_OpenAITextEmbeddingHandler`` from the settings given at construction.
  """
  def __init__(
      self,
      model_name: str,
      *,
      api_key: Optional[str] = None,
      organization: Optional[str] = None,
      dimensions: Optional[int] = None,
      user: Optional[str] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Configure OpenAI text embeddings for semantic search and RAG pipelines.

    Args:
      model_name: Name of the OpenAI embedding model.
      api_key: OpenAI API key.
      organization: OpenAI organization ID.
      dimensions: Specific embedding dimensions to use (if supported).
      user: End-user identifier for tracking and rate limit calculations.
      max_batch_size: Maximum batch size for requests to the OpenAI API.
      **kwargs: Additional arguments passed to EmbeddingsManager, including
        ModelHandler inference_args.
    """
    rag_adapter = create_rag_adapter()
    super().__init__(type_adapter=rag_adapter, **kwargs)

    # Stored as given; they are read back when the model handler is built.
    self.model_name = model_name
    self.api_key = api_key
    self.organization = organization
    self.dimensions = dimensions
    self.user = user
    self.max_batch_size = max_batch_size

  def get_model_handler(self):
    """Build the OpenAI embedding handler from this manager's settings."""
    handler_config = {
        'model_name': self.model_name,
        'api_key': self.api_key,
        'organization': self.organization,
        'dimensions': self.dimensions,
        'user': self.user,
        'max_batch_size': self.max_batch_size,
    }
    return _OpenAITextEmbeddingHandler(**handler_config)

  def get_ptransform_for_processing(
      self, **kwargs
  ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
    """Return a RunInference transform typed to emit ``Chunk`` elements.

    ``**kwargs`` is accepted but not used here.
    """
    text_handler = _TextEmbeddingHandler(self)
    inference = RunInference(
        model_handler=text_handler, inference_args=self.inference_args)
    return inference.with_output_types(Chunk)
122 changes: 122 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest

import apache_beam as beam
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


@unittest.skipUnless(
    os.environ.get("OPENAI_API_KEY"),
    "OPENAI_API_KEY is not set; skipping tests that need an OpenAI API key")
class OpenAITextEmbeddingsTest(unittest.TestCase):
  """Pipeline tests for the RAG ``OpenAITextEmbeddings`` transform.

  The embedder is built with the API key read from ``OPENAI_API_KEY``, so the
  whole class is skipped when that variable is missing instead of failing.
  """
  def setUp(self):
    # Each test gets its own artifact directory, removed again in tearDown.
    self.artifact_location = tempfile.mkdtemp(prefix='openai_')
    self.test_chunks = [
        Chunk(
            content=Content(text="This is a test sentence."),
            id="1",
            metadata={
                "source": "test.txt", "language": "en"
            }),
        Chunk(
            content=Content(text="Another example."),
            id="2",
            metadata={
                "source": "test.txt", "language": "en"
            })
    ]

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def _expected_chunks(self, dimensions):
    """Return the input chunks with a placeholder embedding attached.

    The embedding values are zeros because the comparison function only
    checks the embedding length and element type, not the values.
    """
    return [
        Chunk(
            id=chunk.id,
            embedding=Embedding(dense_embedding=[0.0] * dimensions),
            metadata=chunk.metadata,
            content=chunk.content) for chunk in self.test_chunks
    ]

  def _assert_pipeline_embeds(self, dimensions):
    """Run the embedding pipeline and check each chunk's embedding size."""
    expected = self._expected_chunks(dimensions)

    embedder = OpenAITextEmbeddings(
        model_name="text-embedding-3-small",
        dimensions=dimensions,
        api_key=os.environ.get("OPENAI_API_KEY"))

    with TestPipeline() as p:
      embeddings = (
          p
          | beam.Create(self.test_chunks)
          | MLTransform(write_artifact_location=self.artifact_location).
          with_transform(embedder))

      assert_that(
          embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))

  def test_embedding_pipeline(self):
    self._assert_pipeline_embeds(1536)

  def test_embedding_pipeline_with_dimensions(self):
    self._assert_pipeline_embeds(512)


if __name__ == '__main__':
unittest.main()
32 changes: 32 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for RAG embeddings."""

from apache_beam.ml.rag.types import Chunk


def chunk_approximately_equals(expected, actual):
  """Return whether two Chunks match, ignoring the embedding values.

  The embedding values themselves are never compared. ``id``, ``metadata``
  and ``content`` must be equal, the dense embeddings must have the same
  length, and every value in ``actual``'s dense embedding must be a
  ``float``.

  Args:
    expected: The reference Chunk; its dense embedding only supplies the
      expected length.
    actual: The Chunk produced by the code under test.

  Returns:
    True if the chunks match as described above, otherwise False. Anything
    that is not a Chunk never matches. When either chunk has no dense
    embedding, they match only if both have none.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  if (expected.id != actual.id or expected.metadata != actual.metadata or
      expected.content != actual.content):
    return False

  # A Chunk can be built without an embedding; report a mismatch instead of
  # raising from inside an equals_fn.
  expected_dense = getattr(expected.embedding, 'dense_embedding', None)
  actual_dense = getattr(actual.embedding, 'dense_embedding', None)
  if expected_dense is None or actual_dense is None:
    return expected_dense is None and actual_dense is None

  return (
      len(expected_dense) == len(actual_dense) and
      all(isinstance(x, float) for x in actual_dense))
14 changes: 1 addition & 13 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import unittest

import apache_beam as beam
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
Expand All @@ -38,19 +39,6 @@
VERTEX_AI_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
  """Return whether two Chunks match, ignoring the embedding values.

  The embedding values themselves are never compared. ``id``, ``metadata``
  and ``content`` must be equal, the dense embeddings must have the same
  length, and every value in ``actual``'s dense embedding must be a
  ``float``.

  Args:
    expected: The reference Chunk; its dense embedding only supplies the
      expected length.
    actual: The Chunk produced by the code under test.

  Returns:
    True if the chunks match as described above, otherwise False. Anything
    that is not a Chunk never matches. When either chunk has no dense
    embedding, they match only if both have none.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  if (expected.id != actual.id or expected.metadata != actual.metadata or
      expected.content != actual.content):
    return False

  # A Chunk can be built without an embedding; report a mismatch instead of
  # raising from inside an equals_fn.
  expected_dense = getattr(expected.embedding, 'dense_embedding', None)
  actual_dense = getattr(actual.embedding, 'dense_embedding', None)
  if expected_dense is None or actual_dense is None:
    return expected_dense is None and actual_dense is None

  return (
      len(expected_dense) == len(actual_dense) and
      all(isinstance(x, float) for x in actual_dense))


@unittest.skipIf(
not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAITextEmbeddingsTest(unittest.TestCase):
Expand Down
13 changes: 7 additions & 6 deletions sdks/python/apache_beam/ml/transforms/embeddings/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
from typing import TypeVar
from typing import Union

import apache_beam as beam
import openai
from openai import APIError
from openai import RateLimitError

import apache_beam as beam
from apache_beam.ml.inference.base import RemoteModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.pvalue import PCollection
from apache_beam.pvalue import Row
from openai import APIError
from openai import RateLimitError

__all__ = ["OpenAITextEmbeddings"]

Expand Down Expand Up @@ -103,12 +104,12 @@ def request(
) -> Iterable:
"""Makes a request to OpenAI embedding API and returns embeddings."""
# Prepare arguments for the API call
kwargs = {
kwargs: dict[str, Any] = {
"model": self.model_name,
"input": batch,
}
if self.dimensions:
kwargs["dimensions"] = [str(self.dimensions)]
kwargs["dimensions"] = self.dimensions
if self.user:
kwargs["user"] = self.user

Expand Down Expand Up @@ -139,7 +140,7 @@ def __init__(
"""
Embedding Config for OpenAI Text Embedding models.
Text Embeddings are generated for a batch of text using the OpenAI API.

Args:
model_name: Name of the OpenAI embedding model
columns: The columns where the embeddings will be stored in the output
Expand Down
Loading
Loading