Skip to content

Commit 466d533

Browse files
sdks/python; add OpenAITextEmbeddings to RAG module
1 parent 9031331 commit 466d533

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""RAG-specific embedding implementations using OpenAI models."""
18+
19+
from typing import Optional
20+
21+
import openai
22+
23+
import apache_beam as beam
24+
from apache_beam.ml.inference.base import RunInference
25+
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
26+
from apache_beam.ml.rag.types import Chunk
27+
from apache_beam.ml.transforms.base import EmbeddingsManager
28+
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
29+
from apache_beam.ml.transforms.embeddings.open_ai import (
30+
_OpenAITextEmbeddingHandler,
31+
)
32+
33+
__all__ = ['OpenAITextEmbeddings']
34+
35+
36+
class OpenAITextEmbeddings(EmbeddingsManager):
37+
def __init__(
38+
self,
39+
model_name: str,
40+
*,
41+
api_key: Optional[str] = None,
42+
organization: Optional[str] = None,
43+
dimensions: Optional[int] = None,
44+
user: Optional[str] = None,
45+
max_batch_size: Optional[int] = None,
46+
**kwargs):
47+
"""Utilizes OpenAI text embeddings for semantic search and RAG pipelines.
48+
49+
Args:
50+
model_name: Name of the OpenAI embedding model
51+
api_key: OpenAI API key
52+
organization: OpenAI organization ID
53+
dimensions: Specific embedding dimensions to use (if supported)
54+
user: End-user identifier for tracking and rate limit calculations
55+
max_batch_size: Maximum batch size for requests to OpenAI API
56+
**kwargs: Additional arguments passed to EmbeddingsManager including
57+
ModelHandler inference_args.
58+
"""
59+
super().__init__(type_adapter=create_rag_adapter(), **kwargs)
60+
self.model_name = model_name
61+
self.api_key = api_key
62+
self.organization = organization
63+
self.dimensions = dimensions
64+
self.user = user
65+
self.max_batch_size = max_batch_size
66+
67+
def get_model_handler(self):
68+
"""Returns model handler configured with RAG adapter."""
69+
return _OpenAITextEmbeddingHandler(
70+
model_name=self.model_name,
71+
api_key=self.api_key,
72+
organization=self.organization,
73+
dimensions=self.dimensions,
74+
user=self.user,
75+
max_batch_size=self.max_batch_size,
76+
)
77+
78+
def get_ptransform_for_processing(
79+
self, **kwargs
80+
) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
81+
"""Returns PTransform that uses the RAG adapter."""
82+
return RunInference(
83+
model_handler=_TextEmbeddingHandler(self),
84+
inference_args=self.inference_args).with_output_types(Chunk)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import os
18+
import shutil
19+
import tempfile
20+
import unittest
21+
22+
import apache_beam as beam
23+
from apache_beam.ml.rag.types import Chunk
24+
from apache_beam.ml.rag.types import Content
25+
from apache_beam.ml.rag.types import Embedding
26+
from apache_beam.ml.transforms.base import MLTransform
27+
from apache_beam.testing.test_pipeline import TestPipeline
28+
from apache_beam.testing.util import assert_that
29+
from apache_beam.testing.util import equal_to
30+
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
31+
32+
def chunk_approximately_equals(expected, actual):
33+
"""Compare embeddings allowing for numerical differences."""
34+
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
35+
return False
36+
37+
return (
38+
expected.id == actual.id and expected.metadata == actual.metadata and
39+
expected.content == actual.content and
40+
len(expected.embedding.dense_embedding) == len(
41+
actual.embedding.dense_embedding) and
42+
all(isinstance(x, float) for x in actual.embedding.dense_embedding))
43+
44+
45+
class OpenAITextEmbeddingsTest(unittest.TestCase):
46+
def setUp(self):
47+
self.artifact_location = tempfile.mkdtemp(prefix='openai_')
48+
self.test_chunks = [
49+
Chunk(
50+
content=Content(text="This is a test sentence."),
51+
id="1",
52+
metadata={
53+
"source": "test.txt", "language": "en"
54+
}),
55+
Chunk(
56+
content=Content(text="Another example."),
57+
id="2",
58+
metadata={
59+
"source": "test.txt", "language": "en"
60+
})
61+
]
62+
63+
def tearDown(self) -> None:
64+
shutil.rmtree(self.artifact_location)
65+
66+
def test_embedding_pipeline(self):
67+
expected = [
68+
Chunk(
69+
id="1",
70+
embedding=Embedding(dense_embedding=[0.0] * 1536),
71+
metadata={
72+
"source": "test.txt", "language": "en"
73+
},
74+
content=Content(text="This is a test sentence.")),
75+
Chunk(
76+
id="2",
77+
embedding=Embedding(dense_embedding=[0.0] * 1536),
78+
metadata={
79+
"source": "test.txt", "language": "en"
80+
},
81+
content=Content(text="Another example."))
82+
]
83+
84+
embedder = OpenAITextEmbeddings(
85+
model_name="text-embedding-3-small",
86+
dimensions=1536,
87+
api_key=os.environ.get("OPENAI_API_KEY"))
88+
89+
with TestPipeline() as p:
90+
embeddings = (
91+
p
92+
| beam.Create(self.test_chunks)
93+
| MLTransform(write_artifact_location=self.artifact_location).
94+
with_transform(embedder))
95+
96+
assert_that(
97+
embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
98+
99+
def test_embedding_pipeline_with_dimensions(self):
100+
expected = [
101+
Chunk(
102+
id="1",
103+
embedding=Embedding(dense_embedding=[0.0] * 512),
104+
metadata={
105+
"source": "test.txt", "language": "en"
106+
},
107+
content=Content(text="This is a test sentence.")),
108+
Chunk(
109+
id="2",
110+
embedding=Embedding(dense_embedding=[0.0] * 512),
111+
metadata={
112+
"source": "test.txt", "language": "en"
113+
},
114+
content=Content(text="Another example."))
115+
]
116+
117+
embedder = OpenAITextEmbeddings(
118+
model_name="text-embedding-3-small",
119+
dimensions=512,
120+
api_key=os.environ.get("OPENAI_API_KEY"))
121+
122+
with TestPipeline() as p:
123+
embeddings = (
124+
p
125+
| beam.Create(self.test_chunks)
126+
| MLTransform(write_artifact_location=self.artifact_location).
127+
with_transform(embedder))
128+
129+
assert_that(
130+
embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
131+
132+
133+
if __name__ == '__main__':
134+
unittest.main()

0 commit comments

Comments
 (0)