Skip to content

Commit cb678ce

Browse files
gdyre and georgedyre
authored
Georgedyre/openai upgrade (#2) (Azure#33369)
* Keep original question for first in conversation
* Add OpenAI version 1.0 support
* Fix openai client

Co-authored-by: George Dyre <[email protected]>
1 parent 2bcc83c commit cb678ce

File tree

1 file changed

+76
-26
lines changed
  • sdk/ai/azure-ai-generative/azure/ai/generative/synthetic

1 file changed

+76
-26
lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/qa.py

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from enum import Enum
1313
from functools import lru_cache
1414
from typing import Dict, List, Tuple, Any, Union
15-
import openai
1615
from collections import defaultdict
1716
from azure.ai.resources.entities import BaseConnection
1817
from azure.identity import DefaultAzureCredential
@@ -22,27 +21,62 @@
2221
print("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative")
2322
raise e
2423

24+
# Detect the installed openai package version once at import time so the rest
# of this module can branch between the pre-1.0 and 1.0+ client APIs, and
# build the tuple of transient error types the retry helpers catch.
try:
    import pkg_resources

    openai_version_str = pkg_resources.get_distribution("openai").version
    openai_version = pkg_resources.parse_version(openai_version_str)
    import openai

    if openai_version >= pkg_resources.parse_version("1.0.0"):
        # openai>=1.0 moved its exception classes to the top-level module.
        _RETRY_ERRORS = (
            openai.APIConnectionError,
            openai.APIError,
            openai.APIStatusError,
        )
    else:
        # Legacy (<1.0) exception classes live under the openai.error submodule.
        _RETRY_ERRORS = (
            openai.error.ServiceUnavailableError,
            openai.error.APIError,
            openai.error.RateLimitError,
            openai.error.APIConnectionError,
            openai.error.Timeout,
        )
except ImportError as e:
    print("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative")
    raise e
2547

2648
_TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
2749
activity_logger = ActivityLogger(__name__)
2850
logger, module_logger = activity_logger.package_logger, activity_logger.module_logger
2951

3052
_DEFAULT_AOAI_VERSION = "2023-07-01-preview"
3153
_MAX_RETRIES = 7
32-
_RETRY_ERRORS = (
33-
openai.error.ServiceUnavailableError,
34-
openai.error.APIError,
35-
openai.error.RateLimitError,
36-
openai.error.APIConnectionError,
37-
openai.error.Timeout,
38-
)
54+
3955

4056

4157
def _completion_with_retries(*args, **kwargs):
4258
n = 1
4359
while True:
4460
try:
45-
response = openai.ChatCompletion.create(*args, **kwargs)
61+
if openai_version >= pkg_resources.parse_version("1.0.0"):
62+
if kwargs["api_type"].lower() == "azure":
63+
from openai import AzureOpenAI
64+
client = AzureOpenAI(
65+
azure_endpoint = kwargs["api_base"],
66+
api_key=kwargs["api_key"],
67+
api_version=kwargs["api_version"]
68+
)
69+
response = client.chat.completions.create(messages=kwargs["messages"], model=kwargs["deployment_id"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
70+
else:
71+
from openai import OpenAI
72+
client = OpenAI(
73+
api_key=kwargs["api_key"],
74+
)
75+
response = client.chat.completions.create(messages=kwargs["messages"], model=kwargs["model"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
76+
return response.choices[0].message.content, dict(response.usage)
77+
else:
78+
response = openai.ChatCompletion.create(*args, **kwargs)
79+
return response["choices"][0].message.content, response["usage"]
4680
except _RETRY_ERRORS as e:
4781
if n > _MAX_RETRIES:
4882
raise
@@ -51,14 +85,31 @@ def _completion_with_retries(*args, **kwargs):
5185
time.sleep(secs)
5286
n += 1
5387
continue
54-
return response
5588

5689

5790
async def _completion_with_retries_async(*args, **kwargs):
5891
n = 1
5992
while True:
6093
try:
61-
response = await openai.ChatCompletion.acreate(*args, **kwargs)
94+
if openai_version >= pkg_resources.parse_version("1.0.0"):
95+
if kwargs["api_type"].lower() == "azure":
96+
from openai import AsyncAzureOpenAI
97+
client = AsyncAzureOpenAI(
98+
azure_endpoint = kwargs["api_base"],
99+
api_key=kwargs["api_key"],
100+
api_version=kwargs["api_version"]
101+
)
102+
response = await client.chat.completions.create(messages=kwargs["messages"], model=kwargs["deployment_id"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
103+
else:
104+
from openai import AsyncOpenAI
105+
client = AsyncOpenAI(
106+
api_key=kwargs["api_key"],
107+
)
108+
response = await client.chat.completions.create(messages=kwargs["messages"], model=kwargs["model"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
109+
return response.choices[0].message.content, dict(response.usage)
110+
else:
111+
response = openai.ChatCompletion.create(*args, **kwargs)
112+
return response["choices"][0].message.content, response["usage"]
62113
except _RETRY_ERRORS as e:
63114
if n > _MAX_RETRIES:
64115
raise
@@ -67,7 +118,6 @@ async def _completion_with_retries_async(*args, **kwargs):
67118
await asyncio.sleep(secs)
68119
n += 1
69120
continue
70-
return response
71121

72122
class OutputStructure(str, Enum):
73123
"""OutputStructure defines what structure the QAs should be written to file in."""
@@ -190,15 +240,16 @@ def _merge_token_usage(self, token_usage: Dict, token_usage2: Dict) -> Dict:
190240
return {name: count + token_usage[name] for name, count in token_usage2.items()}
191241

192242
def _modify_conversation_questions(self, questions) -> Tuple[List[str], Dict]:
    """Rewrite conversation follow-up questions via a chat completion.

    Sends the questions through the "modify conversation" prompt, parses the
    rewritten questions out of the model's reply, and returns them together
    with the token usage of the call.
    """
    reply_text, token_usage = _completion_with_retries(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )

    rewritten, _ = self._parse_qa_from_response(reply_text)
    # The opening question is left untouched so proper nouns survive rewriting.
    rewritten[0] = questions[0]
    assert len(rewritten) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    return rewritten, token_usage
202253

203254
@distributed_trace
204255
@monitor_with_activity(logger, "QADataGenerator.Export", ActivityType.INTERNALCALL)
@@ -266,13 +317,12 @@ def export_to_file(self, output_path: str, qa_type: QAType, results: Union[List,
266317
@monitor_with_activity(logger, "QADataGenerator.Generate", ActivityType.INTERNALCALL)
267318
def generate(self, text: str, qa_type: QAType, num_questions: int = None) -> Dict:
268319
self._validate(qa_type, num_questions)
269-
response = _completion_with_retries(
320+
content, token_usage = _completion_with_retries(
270321
messages=self._get_messages_for_qa_type(qa_type, text, num_questions),
271322
**self._chat_completion_params,
272323
)
273-
questions, answers = self._parse_qa_from_response(response["choices"][0].message.content)
324+
questions, answers = self._parse_qa_from_response(content)
274325
assert len(questions) == len(answers), self._PARSING_ERR_UNEQUAL_QA
275-
token_usage = response["usage"]
276326
if qa_type == QAType.CONVERSATION:
277327
questions, token_usage2 = self._modify_conversation_questions(questions)
278328
token_usage = self._merge_token_usage(token_usage, token_usage2)
@@ -282,27 +332,27 @@ def generate(self, text: str, qa_type: QAType, num_questions: int = None) -> Dic
282332
}
283333

284334
async def _modify_conversation_questions_async(self, questions) -> Tuple[List[str], Dict]:
    """Async variant of ``_modify_conversation_questions``.

    Sends the questions through the "modify conversation" prompt, parses the
    rewritten questions out of the model's reply, and returns them together
    with the token usage of the call.
    """
    reply_text, token_usage = await _completion_with_retries_async(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )

    rewritten, _ = self._parse_qa_from_response(reply_text)
    # The opening question is left untouched so proper nouns survive rewriting.
    rewritten[0] = questions[0]
    assert len(rewritten) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    return rewritten, token_usage
294345

295346
@distributed_trace
296347
@monitor_with_activity(logger, "QADataGenerator.GenerateAsync", ActivityType.INTERNALCALL)
297348
async def generate_async(self, text: str, qa_type: QAType, num_questions: int = None) -> Dict:
298349
self._validate(qa_type, num_questions)
299-
response = await _completion_with_retries_async(
350+
content, token_usage = await _completion_with_retries_async(
300351
messages=self._get_messages_for_qa_type(qa_type, text, num_questions),
301352
**self._chat_completion_params,
302353
)
303-
questions, answers = self._parse_qa_from_response(response["choices"][0].message.content)
354+
questions, answers = self._parse_qa_from_response(content)
304355
assert len(questions) == len(answers), self._PARSING_ERR_UNEQUAL_QA
305-
token_usage = response["usage"]
306356
if qa_type == QAType.CONVERSATION:
307357
questions, token_usage2 = await self._modify_conversation_questions_async(questions)
308358
token_usage = self._merge_token_usage(token_usage, token_usage2)

0 commit comments

Comments
 (0)