11"""Integration with OpenAI's API."""
2+ import copy
23import functools
3- import warnings
44from dataclasses import asdict , dataclass , field , replace
5- from itertools import zip_longest
6- from typing import Callable , Dict , List , Optional , Set , Tuple , Union
5+ from typing import Callable , Dict , List , Optional , Tuple , Union
76
87import numpy as np
98
@@ -74,7 +73,6 @@ def __init__(
7473 self ,
7574 client ,
7675 config ,
77- tokenizer = None ,
7876 system_prompt : Optional [str ] = None ,
7977 ):
8078 """Create an `OpenAI` instance.
@@ -89,13 +87,9 @@ def __init__(
8987 config
9088 An instance of `OpenAIConfig`. Can be useful to specify some
9189 parameters that cannot be set by calling this class' methods.
92- tokenizer
93- The tokenizer associated with the model the client connects to.
94-
9590 """
9691
9792 self .client = client
98- self .tokenizer = tokenizer
9993 self .config = config
10094
10195 # We count the total number of prompt and generated tokens as returned
@@ -104,6 +98,8 @@ def __init__(
10498 self .prompt_tokens = 0
10599 self .completion_tokens = 0
106100
101+ self .format_sequence = lambda seq : seq
102+
107103 def __call__ (
108104 self ,
109105 prompt : Union [str , List [str ]],
@@ -152,107 +148,17 @@ def __call__(
152148 self .prompt_tokens += prompt_tokens
153149 self .completion_tokens += completion_tokens
154150
155- return response
151+ return self . format_sequence ( response )
156152
def stream(self, *args, **kwargs):
    """Streaming interface placeholder.

    Raises
    ------
    NotImplementedError
        Always; this OpenAI integration does not support token streaming.
    """
    message = "Streaming is currently not supported for the OpenAI API"
    raise NotImplementedError(message)
161157
def generate_choice(
    self,
    prompt: str,
    choices: List[str],
    max_tokens: Optional[int] = None,
    system_prompt: Optional[str] = None,
) -> str:
    """Call the OpenAI API to generate one of several choices.

    Works by repeatedly calling the chat endpoint with a `logit_bias` mask
    that only allows tokens belonging to the remaining choices, accumulating
    the matched tokens until exactly one choice has been produced.

    Parameters
    ----------
    prompt
        A string or list of strings that will be used to prompt the model
    choices
        The list of strings between which we ask the model to choose
    max_tokens
        The maximum number of tokens to generate
    system_prompt
        The content of the system message that precedes the user's prompt.

    """
    if self.tokenizer is None:
        raise ValueError(
            "You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`"
        )

    config = replace(self.config, max_tokens=max_tokens)

    # When `greedy` is True the next request only biases the tokens that can
    # appear at the current position; otherwise an "optimistic" mask biasing
    # tokens from all remaining positions is used.
    greedy = False
    decoded: List[str] = []
    encoded_choices_left: List[List[int]] = [
        self.tokenizer.encode(word) for word in choices
    ]

    while len(encoded_choices_left) > 0:
        # Never ask for more tokens than the longest remaining choice needs.
        max_tokens_left = max([len(tokens) for tokens in encoded_choices_left])
        # transposed_choices_left[n] is the set of tokens any remaining
        # choice may have at position n (None padding from ragged lengths
        # is dropped).
        transposed_choices_left: List[Set] = [
            {item for item in subset if item is not None}
            for subset in zip_longest(*encoded_choices_left)
        ]

        if not greedy:
            mask = build_optimistic_mask(transposed_choices_left)
        else:
            mask = {}
            for token in transposed_choices_left[0]:  # build greedy mask
                mask[token] = 100

        # No tokens left to bias means nothing more can be matched.
        if len(mask) == 0:
            break

        config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)

        response, prompt_tokens, completion_tokens = generate_chat(
            prompt, system_prompt, self.client, config
        )
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        encoded_response = self.tokenizer.encode(response)

        if encoded_response in encoded_choices_left:
            # The model produced one of the remaining choices verbatim.
            decoded.append(response)
            break
        else:
            # Keep only the part of the response shared with some choice,
            # and trim that shared prefix off the surviving choices.
            (
                encoded_response,
                encoded_choices_left,
            ) = find_response_choices_intersection(
                encoded_response, encoded_choices_left
            )

            if len(encoded_response) == 0:
                greedy = True  # next iteration will be "greedy"
                continue
            else:
                decoded.append("".join(self.tokenizer.decode(encoded_response)))

                if len(encoded_choices_left) == 1:  # only one choice left
                    choice_left = self.tokenizer.decode(encoded_choices_left[0])
                    decoded.append(choice_left)
                    break

                greedy = False  # after each success, stay with (or switch to) "optimistic" approach

        # Feed the text matched so far back into the prompt so the model
        # continues from it on the next request.
        prompt = prompt + "".join(decoded)

    choice = "".join(decoded)

    return choice
252-
def generate_json(self):
    """Call the OpenAI API to generate a JSON object.

    Not implemented for this integration; calling it always raises.
    """
    raise NotImplementedError
def new_with_replacements(self, **kwargs):
    """Return a shallow copy of this instance whose `config` has the given
    fields replaced.

    Parameters
    ----------
    kwargs
        Field/value pairs forwarded to `dataclasses.replace` on the copy's
        `config`. The original instance's config is left untouched.
    """
    clone = copy.copy(self)
    clone.config = replace(clone.config, **kwargs)
    return clone
256162
257163 def __str__ (self ):
258164 return self .__class__ .__name__ + " API"
@@ -313,81 +219,6 @@ async def call_api(prompt, system_prompt, config):
313219 return results , usage ["prompt_tokens" ], usage ["completion_tokens" ]
314220
315221
def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]:
    """Return the longest prefix of `response` that matches the start of `choice`."""
    matched = 0
    for response_token, choice_token in zip(response, choice):
        if response_token != choice_token:
            break
        matched += 1
    return response[:matched]
323-
324-
def find_response_choices_intersection(
    response: List[int], choices: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
    """Find the longest intersection between the response and the different
    choices.

    Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
    `[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2, 3]`
    as the intersection, and `[[]]` as the list of choices left.

    Parameters
    ----------
    response
        The model's response
    choices
        The remaining possible choices

    Returns
    -------
    A tuple that contains the longest intersection between the response and the
    different choices, and the choices which start with this intersection, with the
    intersection removed.

    """
    max_len_prefix = 0
    choices_left: List[List[int]] = []
    longest_prefix: List[int] = []
    for choice in choices:
        # Find the longest intersection between the response and the choice.
        prefix = find_longest_intersection(response, choice)

        if len(prefix) > max_len_prefix:
            # Strictly longer match: previous survivors are discarded.
            max_len_prefix = len(prefix)
            choices_left = [choice[len(prefix) :]]
            longest_prefix = prefix

        elif len(prefix) == max_len_prefix:
            # Ties (including zero-length prefixes while nothing has matched
            # yet) keep the choice in play, minus the shared prefix.
            choices_left.append(choice[len(prefix) :])

    return longest_prefix, choices_left
365-
366-
def build_optimistic_mask(
    transposed: List[Set[int]], max_mask_size: int = 300
) -> Dict[int, int]:
    """Build the largest logit-bias mask possible.

    Tokens are added from left to right: if the encoded choices are e.g.
    `[[1,2], [3,4]]`, `1` and `3` enter the mask before `2` and `4`.

    Parameters
    ----------
    transposed
        A list of sets that contain the nth token of each choice.
    max_mask_size
        Hard cap on the number of biased tokens (300 — presumably the API's
        `logit_bias` limit; confirm against the OpenAI docs).

    """
    bias: Dict[int, int] = {}
    for token_group in transposed:
        for candidate in token_group:
            if len(bias) == max_mask_size:
                return bias
            bias[candidate] = 100
    return bias
389-
390-
391222def error_handler (api_call_fn : Callable ) -> Callable :
392223 """Handle OpenAI API errors and missing API key."""
393224
@@ -427,11 +258,10 @@ def openai_model(
427258 ** openai_client_params ,
428259):
429260 try :
430- import tiktoken
431261 from openai import AsyncOpenAI
432262 except ImportError :
433263 raise ImportError (
434- "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration."
264+ "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
435265 )
436266
437267 if config is not None :
@@ -441,15 +271,7 @@ def openai_model(
441271
442272 client = AsyncOpenAI (** openai_client_params )
443273
444- try :
445- tokenizer = tiktoken .encoding_for_model (model_name )
446- except KeyError :
447- warnings .warn (
448- f"Could not find a tokenizer for model { model_name } . Using default cl100k_base."
449- )
450- tokenizer = tiktoken .get_encoding ("cl100k_base" )
451-
452- return OpenAI (client , config , tokenizer )
274+ return OpenAI (client , config )
453275
454276
455277def azure_openai (
@@ -459,11 +281,10 @@ def azure_openai(
459281 ** azure_openai_client_params ,
460282):
461283 try :
462- import tiktoken
463284 from openai import AsyncAzureOpenAI
464285 except ImportError :
465286 raise ImportError (
466- "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration."
287+ "The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
467288 )
468289
469290 if config is not None :
@@ -473,12 +294,4 @@ def azure_openai(
473294
474295 client = AsyncAzureOpenAI (** azure_openai_client_params )
475296
476- try :
477- tokenizer = tiktoken .encoding_for_model (model_name or deployment_name )
478- except KeyError :
479- warnings .warn (
480- f"Could not find a tokenizer for model { model_name or deployment_name } . Using default cl100k_base."
481- )
482- tokenizer = tiktoken .get_encoding ("cl100k_base" )
483-
484- return OpenAI (client , config , tokenizer )
297+ return OpenAI (client , config )