12
12
from enum import Enum
13
13
from functools import lru_cache
14
14
from typing import Dict , List , Tuple , Any , Union
15
- import openai
16
15
from collections import defaultdict
17
16
from azure .ai .resources .entities import BaseConnection
18
17
from azure .identity import DefaultAzureCredential
22
21
print ("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative" )
23
22
raise e
24
23
24
+ try :
25
+ import pkg_resources
26
+ openai_version_str = pkg_resources .get_distribution ("openai" ).version
27
+ openai_version = pkg_resources .parse_version (openai_version_str )
28
+ import openai
29
+ if openai_version >= pkg_resources .parse_version ("1.0.0" ):
30
+ _RETRY_ERRORS = (
31
+ openai .APIConnectionError ,
32
+ openai .APIError ,
33
+ openai .APIStatusError
34
+ )
35
+ else :
36
+ _RETRY_ERRORS = (
37
+ openai .error .ServiceUnavailableError ,
38
+ openai .error .APIError ,
39
+ openai .error .RateLimitError ,
40
+ openai .error .APIConnectionError ,
41
+ openai .error .Timeout ,
42
+ )
43
+
44
+ except ImportError as e :
45
+ print ("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative" )
46
+ raise e
25
47
26
48
_TEMPLATES_DIR = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "templates" )
27
49
activity_logger = ActivityLogger (__name__ )
28
50
logger , module_logger = activity_logger .package_logger , activity_logger .module_logger
29
51
30
52
_DEFAULT_AOAI_VERSION = "2023-07-01-preview"
31
53
_MAX_RETRIES = 7
32
- _RETRY_ERRORS = (
33
- openai .error .ServiceUnavailableError ,
34
- openai .error .APIError ,
35
- openai .error .RateLimitError ,
36
- openai .error .APIConnectionError ,
37
- openai .error .Timeout ,
38
- )
54
+
39
55
40
56
41
57
def _completion_with_retries (* args , ** kwargs ):
42
58
n = 1
43
59
while True :
44
60
try :
45
- response = openai .ChatCompletion .create (* args , ** kwargs )
61
+ if openai_version >= pkg_resources .parse_version ("1.0.0" ):
62
+ if kwargs ["api_type" ].lower () == "azure" :
63
+ from openai import AzureOpenAI
64
+ client = AzureOpenAI (
65
+ azure_endpoint = kwargs ["api_base" ],
66
+ api_key = kwargs ["api_key" ],
67
+ api_version = kwargs ["api_version" ]
68
+ )
69
+ response = client .chat .completions .create (messages = kwargs ["messages" ], model = kwargs ["deployment_id" ], temperature = kwargs ["temperature" ], max_tokens = kwargs ["max_tokens" ])
70
+ else :
71
+ from openai import OpenAI
72
+ client = OpenAI (
73
+ api_key = kwargs ["api_key" ],
74
+ )
75
+ response = client .chat .completions .create (messages = kwargs ["messages" ], model = kwargs ["model" ], temperature = kwargs ["temperature" ], max_tokens = kwargs ["max_tokens" ])
76
+ return response .choices [0 ].message .content , dict (response .usage )
77
+ else :
78
+ response = openai .ChatCompletion .create (* args , ** kwargs )
79
+ return response ["choices" ][0 ].message .content , response ["usage" ]
46
80
except _RETRY_ERRORS as e :
47
81
if n > _MAX_RETRIES :
48
82
raise
@@ -51,14 +85,31 @@ def _completion_with_retries(*args, **kwargs):
51
85
time .sleep (secs )
52
86
n += 1
53
87
continue
54
- return response
55
88
56
89
57
90
async def _completion_with_retries_async (* args , ** kwargs ):
58
91
n = 1
59
92
while True :
60
93
try :
61
- response = await openai .ChatCompletion .acreate (* args , ** kwargs )
94
+ if openai_version >= pkg_resources .parse_version ("1.0.0" ):
95
+ if kwargs ["api_type" ].lower () == "azure" :
96
+ from openai import AsyncAzureOpenAI
97
+ client = AsyncAzureOpenAI (
98
+ azure_endpoint = kwargs ["api_base" ],
99
+ api_key = kwargs ["api_key" ],
100
+ api_version = kwargs ["api_version" ]
101
+ )
102
+ response = await client .chat .completions .create (messages = kwargs ["messages" ], model = kwargs ["deployment_id" ], temperature = kwargs ["temperature" ], max_tokens = kwargs ["max_tokens" ])
103
+ else :
104
+ from openai import AsyncOpenAI
105
+ client = AsyncOpenAI (
106
+ api_key = kwargs ["api_key" ],
107
+ )
108
+ response = await client .chat .completions .create (messages = kwargs ["messages" ], model = kwargs ["model" ], temperature = kwargs ["temperature" ], max_tokens = kwargs ["max_tokens" ])
109
+ return response .choices [0 ].message .content , dict (response .usage )
110
+ else :
111
+ response = await openai .ChatCompletion .acreate (* args , ** kwargs )
112
+ return response ["choices" ][0 ].message .content , response ["usage" ]
62
113
except _RETRY_ERRORS as e :
63
114
if n > _MAX_RETRIES :
64
115
raise
@@ -67,7 +118,6 @@ async def _completion_with_retries_async(*args, **kwargs):
67
118
await asyncio .sleep (secs )
68
119
n += 1
69
120
continue
70
- return response
71
121
72
122
class OutputStructure (str , Enum ):
73
123
"""OutputStructure defines what structure the QAs should be written to file in."""
@@ -190,15 +240,16 @@ def _merge_token_usage(self, token_usage: Dict, token_usage2: Dict) -> Dict:
190
240
return {name : count + token_usage [name ] for name , count in token_usage2 .items ()}
191
241
192
242
def _modify_conversation_questions(self, questions) -> Tuple[List[str], Dict]:
    """Rewrite the questions of a conversation with a single chat-completion call.

    The messages built by ``_get_messages_for_modify_conversation`` are sent through
    ``_completion_with_retries`` and the returned text is parsed with
    ``_parse_qa_from_response``. The first question is always restored to its
    original wording.

    :param questions: The conversation's questions, in order.
    :return: A tuple of (modified questions, token usage returned by the completion call).
    :raises AssertionError: If the number of parsed questions differs from ``len(questions)``.
    """
    content, usage = _completion_with_retries(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )

    modified_questions, _ = self._parse_qa_from_response(content)
    # Validate the count before indexing: an empty or short parse must surface as
    # the parsing error below rather than as an IndexError from the assignment.
    assert len(modified_questions) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    # Keep proper nouns in first question of conversation
    modified_questions[0] = questions[0]
    return modified_questions, usage
202
253
203
254
@distributed_trace
204
255
@monitor_with_activity (logger , "QADataGenerator.Export" , ActivityType .INTERNALCALL )
@@ -266,13 +317,12 @@ def export_to_file(self, output_path: str, qa_type: QAType, results: Union[List,
266
317
@monitor_with_activity (logger , "QADataGenerator.Generate" , ActivityType .INTERNALCALL )
267
318
def generate (self , text : str , qa_type : QAType , num_questions : int = None ) -> Dict :
268
319
self ._validate (qa_type , num_questions )
269
- response = _completion_with_retries (
320
+ content , token_usage = _completion_with_retries (
270
321
messages = self ._get_messages_for_qa_type (qa_type , text , num_questions ),
271
322
** self ._chat_completion_params ,
272
323
)
273
- questions , answers = self ._parse_qa_from_response (response [ "choices" ][ 0 ]. message . content )
324
+ questions , answers = self ._parse_qa_from_response (content )
274
325
assert len (questions ) == len (answers ), self ._PARSING_ERR_UNEQUAL_QA
275
- token_usage = response ["usage" ]
276
326
if qa_type == QAType .CONVERSATION :
277
327
questions , token_usage2 = self ._modify_conversation_questions (questions )
278
328
token_usage = self ._merge_token_usage (token_usage , token_usage2 )
@@ -282,27 +332,27 @@ def generate(self, text: str, qa_type: QAType, num_questions: int = None) -> Dic
282
332
}
283
333
284
334
async def _modify_conversation_questions_async(self, questions) -> Tuple[List[str], Dict]:
    """Rewrite the questions of a conversation with a single awaited chat-completion call.

    The messages built by ``_get_messages_for_modify_conversation`` are sent through
    ``_completion_with_retries_async`` and the returned text is parsed with
    ``_parse_qa_from_response``. The first question is always restored to its
    original wording.

    :param questions: The conversation's questions, in order.
    :return: A tuple of (modified questions, token usage returned by the completion call).
    :raises AssertionError: If the number of parsed questions differs from ``len(questions)``.
    """
    content, usage = await _completion_with_retries_async(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )

    modified_questions, _ = self._parse_qa_from_response(content)
    # Validate the count before indexing: an empty or short parse must surface as
    # the parsing error below rather than as an IndexError from the assignment.
    assert len(modified_questions) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    # Keep proper nouns in first question of conversation
    modified_questions[0] = questions[0]
    return modified_questions, usage
294
345
295
346
@distributed_trace
296
347
@monitor_with_activity (logger , "QADataGenerator.GenerateAsync" , ActivityType .INTERNALCALL )
297
348
async def generate_async (self , text : str , qa_type : QAType , num_questions : int = None ) -> Dict :
298
349
self ._validate (qa_type , num_questions )
299
- response = await _completion_with_retries_async (
350
+ content , token_usage = await _completion_with_retries_async (
300
351
messages = self ._get_messages_for_qa_type (qa_type , text , num_questions ),
301
352
** self ._chat_completion_params ,
302
353
)
303
- questions , answers = self ._parse_qa_from_response (response [ "choices" ][ 0 ]. message . content )
354
+ questions , answers = self ._parse_qa_from_response (content )
304
355
assert len (questions ) == len (answers ), self ._PARSING_ERR_UNEQUAL_QA
305
- token_usage = response ["usage" ]
306
356
if qa_type == QAType .CONVERSATION :
307
357
questions , token_usage2 = await self ._modify_conversation_questions_async (questions )
308
358
token_usage = self ._merge_token_usage (token_usage , token_usage2 )
0 commit comments