Skip to content

Commit d5e3766

Browse files
authored
complete implementation of open ai text embedding with test #new (#34700)
* openai embedding text implementation * openai embedding unit test * Update open_ai.py * Update open_ai_test.py * Update open_ai.py * Update open_ai_test.py * Create open_ai_test_requirement.txt * remove the file unless given good detail for the project * Rename open_ai_test.py to open_ai_it_test.py * changes based on lint and format error * bypass the type check * trying to fix the error related to the linting * changes latest * changes * changes * Update open_ai.py * Update open_ai.py * Update open_ai.py * Update open_ai.py * completed the whole open_ai * commit with changes to request of the reviewer * commit * change
1 parent f0c4f81 commit d5e3766

File tree

2 files changed

+459
-0
lines changed

2 files changed

+459
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
from collections.abc import Iterable
18+
from collections.abc import Sequence
19+
from typing import Any
20+
from typing import Optional
21+
from typing import TypeVar
22+
from typing import Union
23+
24+
import apache_beam as beam
25+
import openai
26+
from apache_beam.ml.inference.base import RemoteModelHandler
27+
from apache_beam.ml.inference.base import RunInference
28+
from apache_beam.ml.transforms.base import EmbeddingsManager
29+
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
30+
from apache_beam.pvalue import PCollection
31+
from apache_beam.pvalue import Row
32+
from openai import APIError
33+
from openai import RateLimitError
34+
35+
# Public API of this module; the handler class below is internal-only.
__all__ = ["OpenAITextEmbeddings"]

# Define a type variable for the output
MLTransformOutputT = TypeVar('MLTransformOutputT')

# Default batch size for OpenAI API requests
_DEFAULT_BATCH_SIZE = 20

# Module-level logger. NOTE(review): uses a fixed name instead of __name__;
# confirm this matches the project's logging conventions.
LOGGER = logging.getLogger("OpenAIEmbeddings")
44+
45+
46+
def _retry_on_appropriate_openai_error(exception):
  """Decide whether a failed OpenAI request should be retried.

  Only transient failures are considered retryable: rate limiting
  (429, ``RateLimitError``) and server-side errors (5xx, ``APIError``).

  Args:
    exception: the exception encountered during the request/response loop.

  Returns:
    True when the exception is a ``RateLimitError`` (429) or an
    ``APIError`` (server error); False for anything else.
  """
  retryable_error_types = (RateLimitError, APIError)
  return isinstance(exception, retryable_error_types)
59+
60+
61+
class _OpenAITextEmbeddingHandler(RemoteModelHandler):
  """Remote model handler that fetches text embeddings from the OpenAI API.

  Batches of input strings are sent to the OpenAI embeddings endpoint and
  the resulting embedding vectors are returned in input order. Retries and
  throttling for transient failures (429 / 5xx) are delegated to the
  ``RemoteModelHandler`` base class.

  Note: Intended for internal use and guarantees no backwards compatibility.
  """
  def __init__(
      self,
      model_name: str,
      api_key: Optional[str] = None,
      organization: Optional[str] = None,
      dimensions: Optional[int] = None,
      user: Optional[str] = None,
      max_batch_size: Optional[int] = None,
  ):
    """
    Args:
      model_name: name of the OpenAI embedding model (e.g.
        ``text-embedding-3-small``).
      api_key: OpenAI API key; when omitted the openai client falls back to
        its own environment-based configuration.
      organization: optional OpenAI organization ID.
      dimensions: optional output dimensionality, for models that support
        shortened embeddings.
      user: optional end-user identifier for tracking / rate limiting.
      max_batch_size: maximum number of texts per API request; defaults to
        ``_DEFAULT_BATCH_SIZE``.
    """
    super().__init__(
        namespace="OpenAITextEmbeddings",
        num_retries=5,
        throttle_delay_secs=5,
        retry_filter=_retry_on_appropriate_openai_error)
    self.model_name = model_name
    self.api_key = api_key
    self.organization = organization
    self.dimensions = dimensions
    self.user = user
    self.max_batch_size = max_batch_size or _DEFAULT_BATCH_SIZE

  def create_client(self):
    """Creates and returns an OpenAI client.

    When no explicit ``api_key`` was given, the openai library resolves the
    key itself (e.g. from the OPENAI_API_KEY environment variable).
    """
    if self.api_key:
      return openai.OpenAI(
          api_key=self.api_key,
          organization=self.organization,
      )
    return openai.OpenAI(organization=self.organization)

  def request(
      self,
      batch: Sequence[str],
      model: Any,
      inference_args: Optional[dict[str, Any]] = None,
  ) -> Iterable:
    """Makes a request to the OpenAI embedding API and returns embeddings.

    Args:
      batch: the texts to embed.
      model: an OpenAI client as produced by ``create_client``.
      inference_args: unused; present for interface compatibility.

    Returns:
      A list of embedding vectors, one per input text, in input order.
    """
    # Prepare arguments for the API call.
    kwargs: dict[str, Any] = {
        "model": self.model_name,
        "input": batch,
    }
    if self.dimensions:
      # The OpenAI embeddings API expects `dimensions` as a plain int.
      # (Previously this was wrapped as [str(...)], which sends the wrong
      # type and causes the request to be rejected.)
      kwargs["dimensions"] = self.dimensions
    if self.user:
      kwargs["user"] = self.user

    # Make the API call - let RemoteModelHandler handle retries and exceptions
    response = model.embeddings.create(**kwargs)
    return [item.embedding for item in response.data]

  def batch_elements_kwargs(self) -> dict[str, Any]:
    """Return kwargs suitable for BatchElements with appropriate batch size"""
    return {'max_batch_size': self.max_batch_size}

  def __repr__(self):
    return 'OpenAITextEmbeddings'
125+
126+
127+
class OpenAITextEmbeddings(EmbeddingsManager):
  @beam.typehints.with_output_types(PCollection[Union[MLTransformOutputT, Row]])
  def __init__(
      self,
      model_name: str,
      columns: list[str],
      api_key: Optional[str] = None,
      organization: Optional[str] = None,
      dimensions: Optional[int] = None,
      user: Optional[str] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Embedding configuration for OpenAI text embedding models.

    Generates text embeddings for batches of text via the OpenAI API.

    Args:
      model_name: Name of the OpenAI embedding model.
      columns: The columns where the embeddings will be stored in the output.
      api_key: OpenAI API key.
      organization: OpenAI organization ID.
      dimensions: Specific embedding dimensions to use (if model supports it).
      user: End-user identifier for tracking and rate limit calculations.
      max_batch_size: Maximum batch size for requests to the OpenAI API.
    """
    self.model_name = model_name
    self.max_batch_size = max_batch_size
    self.api_key = api_key
    self.organization = organization
    self.dimensions = dimensions
    self.user = user
    super().__init__(columns=columns, **kwargs)

  def get_model_handler(self) -> RemoteModelHandler:
    # Build the internal handler from this manager's configuration.
    handler = _OpenAITextEmbeddingHandler(
        model_name=self.model_name,
        api_key=self.api_key,
        organization=self.organization,
        dimensions=self.dimensions,
        user=self.user,
        max_batch_size=self.max_batch_size,
    )
    return handler

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    # Wrap the handler so only the configured text columns are embedded.
    text_handler = _TextEmbeddingHandler(self)
    return RunInference(
        model_handler=text_handler,
        inference_args=self.inference_args)

0 commit comments

Comments
 (0)