diff --git a/docs/examples/chain_of_density.md b/docs/examples/chain_of_density.md index 8a822fc2e..bd2fea17c 100644 --- a/docs/examples/chain_of_density.md +++ b/docs/examples/chain_of_density.md @@ -67,13 +67,13 @@ model = outlines.from_transformers( transformers.AutoTokenizer.from_pretrained(MODEL_NAME) ) prompt = chain_of_density(article=article) -result = model(prompt, Summaries, max_new_tokens=2000) +result = model(prompt, Summaries, max_new_tokens=2000).content ``` We can now check the results: ```python -print(result) +print(result.content) # {'summaries': [ # { # 'missing_entities': 'English mathematician, cryptanalyst, philosopher', diff --git a/docs/examples/chain_of_thought.md b/docs/examples/chain_of_thought.md index 207363764..4d7c14ae2 100644 --- a/docs/examples/chain_of_thought.md +++ b/docs/examples/chain_of_thought.md @@ -111,7 +111,7 @@ We obtain a series of intermediate reasoning steps as well as the conclusion: ```python import json -json_response = json.loads(response) +json_response = json.loads(response.content) print(json_response["reasoning"]) print(json_response["conclusion"]) diff --git a/docs/examples/classification.md b/docs/examples/classification.md index 18a933b47..d5530276b 100644 --- a/docs/examples/classification.md +++ b/docs/examples/classification.md @@ -52,7 +52,7 @@ prompts = [customer_support(request=request) for request in requests] We can now ask the model to classify the requests: ```python -labels = generator(prompts) +labels = generator(prompts).content print(labels) # ['URGENT', 'STANDARD'] ``` @@ -79,7 +79,7 @@ We can then create a generator with the Pydantic model we just defined and call ```python generator = outlines.Generator(model, Classification) -labels = generator(prompts) +labels = generator(prompts).content print(labels) # ['{"label":"URGENT"}', '{ "label": "STANDARD" }'] ``` diff --git a/docs/examples/dating_profiles.md b/docs/examples/dating_profiles.md index 5852e6df9..569753c79 100644 --- a/docs/examples/dating_profiles.md +++ b/docs/examples/dating_profiles.md @@ -165,7 +165,7 @@ it's a good excuse for a date. I watch the latest series because I'm paying, with my hard-earned money, for every streaming service.""" prompt = dating_profile_prompt(description=new_description, examples=samples) -profile = model(prompt, DatingProfile) +profile = model(prompt, DatingProfile).content parsed_profile = DatingProfile.model_validate_json(json.loads(profile)) ``` diff --git a/docs/examples/deploy-using-bentoml.md b/docs/examples/deploy-using-bentoml.md index c0ddf7e2d..1aa94a42d 100644 --- a/docs/examples/deploy-using-bentoml.md +++ b/docs/examples/deploy-using-bentoml.md @@ -154,7 +154,7 @@ We then need to define an HTTP endpoint using `@bentoml.api` to decorate the met from outlines.types import JsonSchema generator = outlines.Generator(self.model, JsonSchema(json_schema)) - character = generator(prompt) + character = generator(prompt).content return json.loads(character) ``` @@ -200,7 +200,7 @@ with bentoml.SyncHTTPClient("http://localhost:3000") as client: response = client.generate( prompt="Give me a character description" ) - print(response) + print(response.content) ``` diff --git a/docs/examples/deploy-using-cerebrium.md b/docs/examples/deploy-using-cerebrium.md index caf1c96b0..e515940f3 100644 --- a/docs/examples/deploy-using-cerebrium.md +++ b/docs/examples/deploy-using-cerebrium.md @@ -110,7 +110,7 @@ def generate( character = generator( f"[INST]Give me a character description. Describe {prompt}.[/INST]" - ) + ).content return character ``` diff --git a/docs/examples/deploy-using-modal.md b/docs/examples/deploy-using-modal.md index b6d9c7582..ad9afa24b 100644 --- a/docs/examples/deploy-using-modal.md +++ b/docs/examples/deploy-using-modal.md @@ -161,7 +161,7 @@ def generate( # by models, so make sure to check the model's documentation. character = generator( f"[INST]Give me a character description. Describe {prompt}.[/INST]" - ) + ).content # Print out the generated character. print(character) diff --git a/docs/examples/earnings-reports.md b/docs/examples/earnings-reports.md index 645726b07..6fc44d8c1 100644 --- a/docs/examples/earnings-reports.md +++ b/docs/examples/earnings-reports.md @@ -246,7 +246,7 @@ Provide the prompt to the model and run it: csv_data = csv_extractor( extract_financial_data_prompt(columns_to_extract, income_statement), max_new_tokens=1024, -) +).content print(csv_data) ``` diff --git a/docs/examples/extract_event_details.py b/docs/examples/extract_event_details.py index cd90bc439..424135aa3 100644 --- a/docs/examples/extract_event_details.py +++ b/docs/examples/extract_event_details.py @@ -45,7 +45,7 @@ class Event(BaseModel): prompt = prompt_template(now=now, message=message) # Extract the event information -event = generator(prompt) +event = generator(prompt).content # type: ignore # Print the current date and time print(f"Today: {now}") diff --git a/docs/examples/extraction.md b/docs/examples/extraction.md index cfb39ca5c..b8b54cc25 100644 --- a/docs/examples/extraction.md +++ b/docs/examples/extraction.md @@ -79,7 +79,7 @@ prompts = [take_order(order=order) for order in orders] generator = outlines.Generator(model, Order) results = generator(prompts) -print(results) +print(results.content) # ['{"pizza": "Pepperoni", "number": 2}', # '{"pizza": "Margherita", "number": 12}'] ``` diff --git a/docs/examples/knowledge_graph_extraction.md b/docs/examples/knowledge_graph_extraction.md index 39b1c1d6b..bbd54f8c2 100644 --- a/docs/examples/knowledge_graph_extraction.md +++ b/docs/examples/knowledge_graph_extraction.md @@ -123,7 +123,7 @@ response = generator(prompt, max_tokens=1024, temperature=0, seed=42) We obtain the nodes and edges of the knowledge graph: ```python -print(response) +print(response.content) # {"nodes":[{"id":1,"label":"Alice","property":"loves,hates"}, # {"id":2,"label":"Bob","property":"loved_by"}, # {"id":3,"label":"Charlie","property":"hated_by"}], @@ -137,12 +137,14 @@ print(response) We can use the [Graphviz library](https://graphviz.readthedocs.io/en/stable/) to visualize the generated knowledge graph. For detailed installation instructions, see [here](https://graphviz.readthedocs.io/en/stable/#installation). ```python +import json from graphviz import Digraph +json_response = json.loads(response.content) dot = Digraph() -for node in response["nodes"]: +for node in json_response["nodes"]: dot.node(str(node["id"]), node["label"], shape='circle', width='1', height='1') -for edge in response["edges"]: +for edge in json_response["edges"]: dot.edge(str(edge["source"]), str(edge["target"]), label=edge["label"]) dot.render('knowledge-graph.gv', view=True) diff --git a/docs/examples/models_playing_chess.md b/docs/examples/models_playing_chess.md index b8d89958f..80a9e9f3f 100644 --- a/docs/examples/models_playing_chess.md +++ b/docs/examples/models_playing_chess.md @@ -66,7 +66,7 @@ board_state = " " turn_number = 0 while not board.is_game_over(): regex_pattern = legal_moves_regex(board) - structured = model(prompt + board_state, regex_pattern) + structured = model(prompt + board_state, regex_pattern).content move = board.parse_san(structured) if turn_number % 2 == 0 : # It's White's turn diff --git a/docs/examples/qa-with-citations.md b/docs/examples/qa-with-citations.md index 756988c3f..f9488b019 100644 --- a/docs/examples/qa-with-citations.md +++ b/docs/examples/qa-with-citations.md @@ -72,7 +72,7 @@ import json generator = outlines.Generator(model, Users) response = generator("Create 5 fake users", max_tokens=1024, temperature=0, seed=42) -response = json.loads(response) +response = json.loads(response.content) print(response['users']) # [{'id': 1, 'first_name': 'John', 'last_name': 'Doe', 'age': 25}, # {'id': 2, 'first_name': 'Jane', 'last_name': 'Doe', 'age': 30}, @@ -164,7 +164,7 @@ I also started the Data Science club at the University of Waterloo and I was the generator = outlines.Generator(model, QuestionAnswer) prompt = hermes_prompt(question=question, context=context, schema=schema) response = generator(prompt, max_tokens=1024, temperature=0, seed=42) -print(response) +print(response.content) # {"question": "What did the author do during college?", "answer": "The author studied Computational Mathematics and physics in university and was also involved in starting the Data Science club, serving as its president for 2 years.", "citations": ["I went to an arts high school but in university I studied Computational Mathematics and physics.", "I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years."]} ``` @@ -222,7 +222,7 @@ for question, context in [ prompt = hermes_prompt(question=question, context=context, schema=schema) generator = outlines.Generator(model, QuestionAnswer) response = generator(prompt, max_tokens=1024, temperature=0, seed=42) - response = json.loads(response) + response = json.loads(response.content) print(question) print(response['answer']) print(response['citations']) diff --git a/docs/examples/react_agent.md b/docs/examples/react_agent.md index 55f150eb0..30e796bd9 100644 --- a/docs/examples/react_agent.md +++ b/docs/examples/react_agent.md @@ -133,7 +133,7 @@ class ChatBot: def execute(self): generator = outlines.Generator(model, Decision) result = generator(self.prompt, max_tokens=1024, temperature=0, seed=42) - return result + return result.content ``` We define a query function: diff --git a/docs/examples/read-pdfs.md b/docs/examples/read-pdfs.md index 81e71f467..3ccf65c84 100644 --- a/docs/examples/read-pdfs.md +++ b/docs/examples/read-pdfs.md @@ -209,7 +209,7 @@ page_summary_generator = outlines.Generator(model, PageSummary) for image in images: result = page_summary_generator({"text": prompt, "images": image}) - print(result) + print(result.content) ``` ### Regular expressions to extract the arxiv paper identifier @@ -317,7 +317,7 @@ categorizer = outlines.Generator(model, Literal["llms", "cell biology", "other"] # Categorize the paper category = categorizer({"text": categorization_instruction, "images": images[0]}) -print(category) +print(category.content) ``` Which should return: @@ -357,7 +357,7 @@ two_image_prompt = tf_processor.apply_chat_template( generator = outlines.Generator(model, Literal["hot dog", "not hot dog"]) result = generator({"text": two_image_prompt, "images": [images[0], images[1]]}) -print(result) +print(result.content) ``` Using the first to pages of the paper (they are not images of hot dogs), we should get diff --git a/docs/examples/receipt-digitization.md b/docs/examples/receipt-digitization.md index 994024a9f..419ef65d5 100644 --- a/docs/examples/receipt-digitization.md +++ b/docs/examples/receipt-digitization.md @@ -225,7 +225,7 @@ result = receipt_summary_generator( {"text": prompt, "images": image}, max_new_tokens=1024 ) -print(result) +print(result.content) ``` ## Output diff --git a/docs/examples/simtom.md b/docs/examples/simtom.md index 53820258e..bf94b2894 100644 --- a/docs/examples/simtom.md +++ b/docs/examples/simtom.md @@ -87,7 +87,7 @@ perspective_prompt = perspective_taking(story=story, character=character) # Call Mistral 7B with the first prompt generator = outlines.Generator(model, PerspectiveTaking) -perspective = generator(perspective_prompt, max_new_tokens=1024) +perspective = generator(perspective_prompt, max_new_tokens=1024).content print(perspective) # {'character': 'Aria', 'events': ['1 Aria entered the front_yard.', '3 The grapefruit is in the green_bucket.', '4 Aria moved the grapefruit to the blue_container.']} @@ -104,7 +104,7 @@ sim_prompt = simulation(events=json.loads(perspective)["events"], name=character generator = outlines.Generator(model, Simulation) result = generator(sim_prompt, max_new_tokens=1024) -print(result) +print(result.content) # {'answer': 'green_bucket'} ``` diff --git a/docs/examples/structured_generation_workflow.md b/docs/examples/structured_generation_workflow.md index 524283f84..cebc9af0f 100644 --- a/docs/examples/structured_generation_workflow.md +++ b/docs/examples/structured_generation_workflow.md @@ -33,7 +33,7 @@ With our prompt ready we can now generate 10 example phone numbers ```python phone_generator_unstruct = outlines.Generator(model) for _ in range(3): - print(phone_generator_unstruct(prompt_phone, max_new_tokens=12)) + print(phone_generator_unstruct(prompt_phone, max_new_tokens=12).content) ``` > I'd be happy to help you generate a realistic phone\ @@ -105,7 +105,7 @@ We're ready to see if structured generation can make an improvement over our ini phone_generator_v1 = outlines.Generator(model, phone_regex_1) for _ in range(3): - print(phone_generator_v1(prompt_phone)) + print(phone_generator_v1(prompt_phone).content) ``` > (206) 555-1234\ (206) 555-1234\ @@ -146,7 +146,7 @@ Now that we've validated, let's generate with this new regex! phone_generator_v2 = outlines.Generator(model, phone_regex_2) for _ in range(3): - print(phone_generator_v2(prompt_phone)) + print(phone_generator_v2(prompt_phone).content) ``` > (206) 867-5309\ @@ -178,9 +178,9 @@ if not re.match(phone_regex_3_error, phone_number): else: matched_string = re.match(phone_regex_3_error, phone_number)[0] if matched_string == phone_number: - print("Successful match") + print("Successful match") else: - print(f"Error {matched_string} != {phone_number}") + print(f"Error {matched_string} != {phone_number}") ``` This prints out: > Error (206) 386-463 != (206) 386-4636 @@ -192,7 +192,7 @@ phone_regex_3_fixed = Regex(r'\([0-9]{3}\) [2-4][7-9][4-6]-[3-6][2-8][1-4][6-9]' phone_generator_v3 = outlines.Generator(model, phone_regex_3_fixed) for _ in range(3): - print(phone_generator_v3(prompt_phone)) + print(phone_generator_v3(prompt_phone).content) ``` >(206) 494-3216\ diff --git a/docs/features/advanced/backends.md b/docs/features/advanced/backends.md index c8b3fa665..27a1d6e60 100644 --- a/docs/features/advanced/backends.md +++ b/docs/features/advanced/backends.md @@ -25,11 +25,11 @@ model = outlines.from_transformers( ) result = model("What is the capital of France?", output_type, backend="llguidance") -print(result) # 'Paris' +print(result.content) # 'Paris' generator = outlines.Generaor(model, output_type) result = generator("What is the capital of France?", backend="xgrammar") -print(result) # 'Paris' +print(result.content) # 'Paris' ``` If you do not provide a value for the `backend` argument, the default value will be used. The default value depends on the type of output type: diff --git a/docs/features/advanced/logits_processors.md b/docs/features/advanced/logits_processors.md index 4a05609e2..235e826da 100644 --- a/docs/features/advanced/logits_processors.md +++ b/docs/features/advanced/logits_processors.md @@ -44,7 +44,7 @@ logits_processor = RegexLogitsProcessor(r"U\+[0-9A-Fa-f]{4,6}", model.tokenizer, generator = Generator(model, processor=logits_processor) response = generator("What's the unicode for the hugging face emoji") -print(response) # U+1F917 +print(response.content) # U+1F917 ``` ## Creating Custom Logits Processors @@ -95,5 +95,5 @@ formatted_prompt = tf_tokenizer.apply_chat_template( generator = Generator(model, processor=logits_processor) response = generator(formatted_prompt) -print(response) # "101111" +print(response.content) # "101111" ``` diff --git a/docs/features/core/generator.md b/docs/features/core/generator.md index e407ed37b..289b4c4b9 100644 --- a/docs/features/core/generator.md +++ b/docs/features/core/generator.md @@ -47,7 +47,7 @@ generator = Generator(model) result = generator("Write a short poem about AI.") # Print the result -print(result) +print(result.content) ``` ## Structured Generation @@ -77,7 +77,7 @@ generator = Generator(model, BookRecommendation) result = generator("Recommend a science fiction book.") # Parse the JSON result into a Pydantic model -book = BookRecommendation.model_validate_json(result) +book = BookRecommendation.model_validate_json(result.content) print(f"{book.title} by {book.author} ({book.year})") ``` @@ -109,7 +109,7 @@ result = generator( ## Return Value -The generator always returns a raw string containing the generated text. When generating structured outputs, you need to parse this string into the desired format. +The generator returns an `Output` instance (or a iterator containing `StreamingOutput` instances in case of streaming). The `content` field contains the generated text as a string. When generating structured outputs, you need to parse this string into the desired format. Unlike in Outlines v0, where the return type could be a parsed object, in v1 you are responsible for parsing the output when needed: @@ -126,7 +126,7 @@ generator = Generator(model, Person) result = generator("Generate a person:") # Parse the result yourself -person = Person.model_validate_json(result) +person = Person.model_validate_json(result.content) ``` ::: outlines.generator.Generator diff --git a/docs/features/core/inputs.md b/docs/features/core/inputs.md index 47c4969c2..79f9434f0 100644 --- a/docs/features/core/inputs.md +++ b/docs/features/core/inputs.md @@ -32,7 +32,7 @@ model = outlines.from_transformers( # Simple text prompt response = model("What's the capital of France?", max_new_tokens=20) -print(response) # 'Paris' +print(response.content) # 'Paris' ``` ## Multimodal Inputs (Vision) @@ -76,16 +76,22 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` ## Chat Inputs For conversational models, you can use the `Chat` class to provide a conversation history with multiple messages. -A `Chat` instance is instantiated with an optional list of messages. Each message must be a dictionary containing two mandatory keys: -- `role`: must be one of `system`, `assistant` or `user` -- `content`: must be either a string or a multimodal input (if the model supports it) +A `Chat` is instantiated with an optional list of messages. The type of each message is defined by the value of the mandatory `role` key. There are 4 types of messages that each have their associated keys: +- `system`: system instructions to give context to the LLM on the task to perform. The only other key is `content` (mandatory). +- `user`: a message from you in the conversation. The only other key is `content` (mandatory). +- `assistant`: a response from the LLM. The other keys are `content` and `tool_calls` (a list of `ToolCall` instances). At least one of those two must be provided. +- `tool`: a tool call response. The other keys are `content` (mandatory), `tool_name` and `tool_call_id`. Depending on the models you are using, one of those two is mandatory. + +Support for the various message types and fields described above depends on the capabilities of the model you are using. Tool calling is limited to a few models at the moment for instance. To know more about tools, consult the dedicated section on [tools](./tools.md). + +An `Output` instance returned by a model can also be added to a `Chat`. It will automatically be turned into a user message. To know more about model outputs, consult the dedicated section on [outputs](./outputs.md). For instance: @@ -149,13 +155,15 @@ print(prompt) # {'role': 'assistant', 'content': 'Excellent, thanks!'} ``` -Finally, there are three convenience method to easily add a message: +There are four convenience method to easily add a message: -- add_system_message -- add_user_message -- add_assistant_message +- `add_system_message` +- `add_user_message` +- `add_assistant_message` +- `add_tool_message` +- `add_output` -As the role is already set, you only need to provide the content. +As the role is already set, you only need to provide values for the other keys of the message type, except for the `add_output` for which you would just provide the model call output. For instance: @@ -200,5 +208,5 @@ prompts = [ # Call it to generate text result = model.batch(prompts, max_new_tokens=20) -print(result) # ['Vilnius', 'Riga', 'Tallinn'] +print([item.content for item in result]) # ['Vilnius', 'Riga', 'Tallinn'] ``` diff --git a/docs/features/core/output_types.md b/docs/features/core/output_types.md index 85451ae95..07919eccc 100644 --- a/docs/features/core/output_types.md +++ b/docs/features/core/output_types.md @@ -48,9 +48,9 @@ def create_character() -> Character: With an Outlines model, you can generate text that respects the type hints above by providing those as the output type: ```python -model("How many minutes are there in one hour", int) # "60" -model("Pizza or burger", Literal["pizza", "burger"]) # "pizza" -model("Create a character", Character, max_new_tokens=100) # '{"name": "James", "birth_date": "1980-05-10)", "skills": ["archery", "negotiation"]}' +model("How many minutes are there in one hour", int).content # "60" +model("Pizza or burger", Literal["pizza", "burger"]).content # "pizza" +model("Create a character", Character, max_new_tokens=100).content # '{"name": "James", "birth_date": "1980-05-10)", "skills": ["archery", "negotiation"]}' ``` An important difference with function type hints though is that an Outlines generator always returns a string. @@ -61,8 +61,8 @@ For instance: ```python result = model("Create a character", Character, max_new_tokens=100) casted_result = Character.model_validate_json(result) -print(result) # '{"name": "Aurora", "birth_date": "1990-06-15", "skills": ["Stealth", "Diplomacy"]}' -print(casted_result) # name=Aurora birth_date=datetime.date(1990, 6, 15) skills=['Stealth', 'Diplomacy'] +print(result).content # '{"name": "Aurora", "birth_date": "1990-06-15", "skills": ["Stealth", "Diplomacy"]}' +print(casted_result).content # name=Aurora birth_date=datetime.date(1990, 6, 15) skills=['Stealth', 'Diplomacy'] ``` ## Output Type Categories diff --git a/docs/features/core/outputs.md b/docs/features/core/outputs.md new file mode 100644 index 000000000..4d58688ab --- /dev/null +++ b/docs/features/core/outputs.md @@ -0,0 +1,66 @@ +--- +title: Outputs +--- + +# Outputs + +## Overview + +Outlines uses two objcets to contain model response: `Ouptut` and `StreamingOutput`. + +They both have two fields: + +- `content`: the raw text reponse returned by the model +- `tool_calls`: a list of `ToolCallOutput` or `StreamingToolCallOutput` instances if the model decided to call a tool instead of giving a response directly. This field can only have a value if you provided a list of tools to the model in the first place. + +To access the text response from the model, you would thus typically only do `reponse.output`. In the case of streaming, it would give you a chunk of the response. + +## Chat + +If you are using a `Chat` input to call the model, you can add the `Output` you received from the model to your `Chat` instance to add a new message that will be part of the conversation provided to the model the next time you can it. + +For instance: + +```python +import transformers +import outlines +from outlines.inputs import Chat, Image + +MODEL_ID = "microsoft/Phi-3-mini-4k-instruct" + +model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID), + transformers.AutoTokenizer.from_pretrained(MODEL_ID), +) + +# Initialize the chat with a system message. +chat_prompt = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, +]) + +# Add a user message to the chat. +chat_prompt.add_user_message("What's the capital of Latvia?") + +# Call the model with the chat input. +response = model(chat_prompt) +print(response.content) # 'The capital of Latvia is Riga.' + +# Add the output to the chat. +chat_prompt.add_output(response) + +# Add another user message to the chat and call the model again. +chat_prompt.add_user_message("How many inhabitants does it have?") +response = model(chat_prompt) +print(response.content) # '600,000' +``` + +## Tool Calls + +As described above, the output you receive from the model can contain a list of `ToolCallOutput` or `StreamingToolCallOutput` instances for the `tool_calls` field if the model decided to first call tools. + +A `ToolCallOutput` or `StreamingToolCallOutput` contains three fields: +- `name`: the name of the tool to call +- `id`: the id of the tool call to make. If provided, it should typically be included in the `ToolMessage` containing the tool response you would add to the `Chat` +- `args`: the arguments to provide to the tool to call. This is a dictionnary for regular call and a string for streaming calls (as it could contain only a chunk of the whole args) + +See the section on [tools](./tools.md) for an explanation on how to use the `ToolCallOutput` to make a tool call. diff --git a/docs/features/core/tools.md b/docs/features/core/tools.md new file mode 100644 index 000000000..1cd70f57a --- /dev/null +++ b/docs/features/core/tools.md @@ -0,0 +1,164 @@ +--- +title: Tools +--- + +# Tools + +## Overview + +Some models support tool calling, meaning that instead of directly providing its final response, the model can require to call tools you have defined and would later use the tool response in its final response. Tool calling typically goes along providing a `Chat` input as it implies a multiturn conversation with the model. + +For the moment, tool calling is supported by three Outlines models: + +- `Anthropic` +- `Gemini` +- `OpenAI` + +## Tool Definition + +Using tool calling starts with defining the tools that the model can call. There are three formats currently supported as described below. + +Once defined, the tools must be provided in a list to the `tools` keyword argument to the `Generator` constructor or to the text generation methods of a model. As such, the interface for `tools` is very similar to that of the `output_type`. + +#### ToolDef + +A tool can first by defined as a dictionnary. A `ToolDef` dict must contain the following keys: + +- `name`: The name of the tool +- `description`: A description of the tool to help the LLM understand its use +- `parameters`: A dictionnary containing the paramters of the tool, using the JSON properties format. If the LLM decides to call the tool, it will provide values for the parameters +- `required`: A list of parameters that are mandatory. All those parameters must be included in the `parameters` key described above + +For instance: + +```python +import openai +from outlines import from_openai +from outlines.inputs import Chat +from outlines.tools import ToolDef + +client = openai.OpenAI() +model = from_openai(client, "gpt-4o") + +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) + +weather_tool = ToolDef( + name="get_weather", + description="Give the weather for a given city, and optionally for a specific hour of the day", + parameters={"city": {"type": "string"}, "hour": {"type": "integer"}}, + required=["city"], +) + +response = model(chat, tools=[weather_tool]) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='call_p7ToNwgrgoEk9poN7PXTELT5', args={'city': 'Tokyo'})] +``` + +#### Function + +A python function can be used as a tool definition. The `description` would then correspond to the docstring while the `parameters` and `required` would be deduced from the signature. + +```python +import openai +from outlines import from_openai +from outlines.inputs import Chat +from typing import Optional + +client = openai.OpenAI() +model = from_openai(client, "gpt-4o") + +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) + +def get_weather(city: str, hour: Optional[int] = None): + """Give the weather for a given city, and optionally for a specific hour of the day""" + pass + +response = model(chat, tools=[get_weather]) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='call_IdsfmBss6XhiBDbchTqp3HHz', args={'city': 'Tokyo'})] +``` + +#### Pydantic model + +Lastly, you can use a Pydantic model to define the interface of your tool. + +```python +import openai +from outlines import from_openai +from outlines.inputs import Chat +from pydantic import BaseModel +from typing import Optional + +client = openai.OpenAI() +model = from_openai(client, "gpt-4o") + +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) + +class GetWeather(BaseModel): + """Give the weather for a given city, and optionally for a specific hour of the day""" + city: str + hour: Optional[int] = None + +response = model(chat, tools=[GetWeather]) +print(response.tool_calls) # [ToolCallOutput(name='GetWeather', id='call_KWfADMEr6dnDDcw1m2dllRvq', args={'city': 'Tokyo'})] +``` + +## Tool Calls and Responses + +If the model decides to call a tool, you'll get a value for the `tool_calls` attribute of the `Output` received. This value is a `OutputToolCall` instance containing three attributes: + +- `name`: The name of the tool to call +- `id`: The id of the tool call to be able to easily link the tool call and the tool response +- `args`: A dictionnary containing for each parameter required by the tool the value provided by the LLM + +You should use the `name` and the `args` to call your tool yourself and get its reponse. Afterward, you can add to your chat the `Output` you first receive and a `ToolMessage` before being able to call the model again to continue the conversation. + +For instance: + +```python +import openai +from outlines import Generator, from_openai +from outlines.inputs import Chat +from typing import Optional + +# Our tool +def get_weather(city: str, hour: Optional[int] = None): + """Give the weather for a given city, and optionally for a specific hour of the day""" + return "20 degrees" + +client = openai.OpenAI() +model = from_openai(client, "gpt-4o") +generator = Generator(model, tools=[get_weather]) + +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) + +response = generator(chat) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='call_NlIGHr8HoiVgSZfOJ7Y5xz35', args={'city': 'Tokyo'})] + +# Add the model response to the chat +chat.add_output(response) + +# Call the tool with the parameters given by the model and add a tool message to the chat +tool_call = response.tool_calls[0] +tool_response = get_weather(**tool_call.args) +chat.add_tool_message( + content=tool_response, + tool_name=tool_call.name, + tool_call_id=tool_call.id +) + +response = generator(chat) +print(response.content) # The weather in Tokyo is currently 20 degrees. +``` + +When using streaming, the response would be a `StreamingOutput` and the `tool_calls` value a list of `StreamingOutputToolCall`. The only difference compared to what's the describe above is that the `args` field would be a string as the value is received by chunks. You need to concatenate the chunks together to get the full `args` to use to call the tool. diff --git a/docs/features/models/anthropic.md b/docs/features/models/anthropic.md index 403040ca4..2129a434e 100644 --- a/docs/features/models/anthropic.md +++ b/docs/features/models/anthropic.md @@ -53,7 +53,7 @@ model = outlines.from_anthropic( # Call it to generate text response = model("What's the capital of Latvia?", max_tokens=20) -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Vision @@ -89,7 +89,7 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Chat @@ -129,7 +129,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Streaming @@ -150,7 +150,38 @@ model = outlines.from_anthropic( # Stream the response for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk) # 'Once...' + print(chunk.content) # 'Once...' +``` + +#### Tools + +Anthropic supports tool calling. To use it, provide a list of `tools` to the model. + +For instance: + +```python +from anthropic import Anthropic +from outlines import from_anthropic +from outlines.inputs import Chat +from typing import Optional + +# Our tool +def get_weather(city: str, hour: Optional[int] = None): + """Give the weather for a given city, and optionally for a specific hour of the day""" + return "20 degrees" + +# Create the model +model = from_anthropic( + Anthropic(), + "claude-3-5-sonnet-latest" +) + +# Call the model with the tool defined above +chat = Chat([ + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) +response = model(chat, tools=[get_weather], max_tokens=100) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='toolu_01WDUo65vCXkrmjD3Yehbc5v', args={'city': 'Tokyo'})] ``` ## Inference arguments diff --git a/docs/features/models/dottxt.md b/docs/features/models/dottxt.md index 306ef2c62..4cd610535 100644 --- a/docs/features/models/dottxt.md +++ b/docs/features/models/dottxt.md @@ -63,8 +63,8 @@ model = outlines.from_dottxt( # Generate structured text result = model("Create a character", Character) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ## Inference arguments diff --git a/docs/features/models/gemini.md b/docs/features/models/gemini.md index ee48041aa..a9bc0000d 100644 --- a/docs/features/models/gemini.md +++ b/docs/features/models/gemini.md @@ -49,7 +49,7 @@ model = outlines.from_gemini( # Call it to generate text result = model("What's the capital of Latvia?", max_output_tokens=20) -print(result) # 'Riga' +print(result.content) # 'Riga' ``` #### Vision @@ -85,7 +85,7 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_output_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Chat @@ -125,7 +125,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_output_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Streaming @@ -146,7 +146,7 @@ model = outlines.from_gemini( # Stream text for chunk in model.stream("Write a short story about a cat.", max_output_tokens=20): - print(chunk) # 'In...' + print(chunk.content) # 'In...' ``` ## Structured Generation @@ -169,7 +169,7 @@ model = outlines.from_gemini(genai.Client(), "gemini-1.5-flash-latest") # Call it with the ouput type to generate structured text result = model("Pizza or burger?", PizzaOrBurger, max_output_tokens=20) -print(result) # 'pizza' +print(result.content) # 'pizza' ``` #### JSON Schema @@ -196,8 +196,8 @@ model = outlines.from_gemini(genai.Client(), "gemini-1.5-flash-latest") # Call it with the ouput type to generate structured text result = model("Create a character", Character) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` #### Lists of Structured Objects @@ -220,7 +220,39 @@ model = outlines.from_gemini(genai.Client(), "gemini-1.5-flash-latest") # Call it with the ouput type to generate structured text result = model("Create a character", list[Character]) -print(result) # '[{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}, {["name":...' +print(result.content) # '[{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}, {["name":...' +``` + +## Tools + +Gemini supports tool calling. To use it, provide a list of `tools` to the model. + +For instance: + +```python +from google import genai +from outlines import from_gemini +from outlines.inputs import Chat +from typing import Optional + +# Our tool +def get_weather(city: str, hour: Optional[int] = None): + """Give the weather for a given city, and optionally for a specific hour of the day""" + return "20 degrees" + +# Create the model +model = outlines.from_gemini( + genai.Client(), + "gemini-1.5-flash-latest" +) + +# Call the model with the tool defined above +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) +response = model(chat, tools=[get_weather], max_tokens=100) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='toolu_01WDUo65vCXkrmjD3Yehbc5v', args={'city': 'Tokyo'})] ``` !!! Attention diff --git a/docs/features/models/index.md b/docs/features/models/index.md index 8049107e7..ce1d010d1 100644 --- a/docs/features/models/index.md +++ b/docs/features/models/index.md @@ -30,16 +30,16 @@ model = from_transformers( # Call it directly response = model("How many countries are there in the world", max_new_tokens=20) -print(response) # 'There are 200 countries in the world.' +print(response.content) # 'There are 200 countries in the world.' # Call it directly with an output_type response = model("How many countries are there in the world", int, max_new_tokens=20) -print(response) # '200' +print(response.content) # '200' # Create a generator first and then call it generator = Generator(model, int) response = generator("How many countries are there in the world") -print(response) # '200' +print(response.content) # '200' ``` Some models support streaming through a `stream` method. It takes the same argument as the `__call__` method, but returns an iterator instead of a string. @@ -58,7 +58,7 @@ model = from_openai( # Stream the response for chunk in model.stream("Tell a short story about a cat.", max_tokens=50): - print(chunk) # 'This...' + print(chunk.content) # 'This...' ``` Additionally, some models support batch processing through a `batch` method. It's similar to the `__call__` method, but takes a list of prompts instead of a single prompt and returns a list of strings. @@ -77,7 +77,7 @@ model = from_transformers( # Call it directly response = model.batch(["What's the capital of Latvia?", "What's the capital of Estonia?"], max_new_tokens=20) -print(response) # ['Riga', 'Tallinn'] +print(response.content) # ['Riga', 'Tallinn'] ``` ## Features Matrix diff --git a/docs/features/models/llamacpp.md b/docs/features/models/llamacpp.md index 1dd2c31f2..f032ce38d 100644 --- a/docs/features/models/llamacpp.md +++ b/docs/features/models/llamacpp.md @@ -51,7 +51,7 @@ model = outlines.from_llamacpp( # Call it to generate text result = model("What's the capital of Latvia?", max_tokens=20) -print(result) # 'Riga' +print(result.content) # 'Riga' ``` #### Chat @@ -81,7 +81,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'Riga.' +print(response.content) # 'Riga.' ``` #### Streaming @@ -104,7 +104,7 @@ model = outlines.from_llamacpp( # Stream text for chunk in model.stream("Write a short story about a cat.", max_tokens=100): - print(chunk) # 'In...' + print(chunk.content) # 'In...' ``` ## Structured Generation @@ -127,7 +127,7 @@ model = outlines.from_llamacpp( ) result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ``` ### JSON Schema @@ -151,8 +151,8 @@ model = outlines.from_llamacpp( ) result = model("Create a character.", output_type=Character, max_tokens=200) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ### Multiple Choice @@ -172,7 +172,7 @@ model = outlines.from_llamacpp( ) result = model("What is the capital of France?", output_type) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` ### Regex @@ -192,7 +192,7 @@ model = outlines.from_llamacpp( ) result = model("Generate a fake social security number.", output_type) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Context-free grammar @@ -215,7 +215,7 @@ model = outlines.from_llamacpp( ) result = model("Are you feeling good today?", output_type) -print(result) # 'yes' +print(result.content) # 'yes' ``` ## Inference Arguments diff --git a/docs/features/models/mlxlm.md b/docs/features/models/mlxlm.md index 6ccf7f38a..d3808cf31 100644 --- a/docs/features/models/mlxlm.md +++ b/docs/features/models/mlxlm.md @@ -50,7 +50,7 @@ model = outlines.from_mlxlm( # Call it to generate text result = model("What's the capital of Latvia?", max_tokens=20) -print(result) # 'Riga' +print(result.content) # 'Riga' ``` #### Chat @@ -77,7 +77,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'Riga.' +print(response.content) # 'Riga.' ``` #### Streaming @@ -95,7 +95,7 @@ model = outlines.from_mlxlm( # Stream text for chunk in model.stream("Write a short story about a cat.", max_tokens=100): - print(chunk) # 'In...' + print(chunk.content) # 'In...' ``` ## Structured Generation @@ -115,7 +115,7 @@ model = outlines.from_mlxlm( ) result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ``` #### JSON Schema @@ -136,8 +136,8 @@ model = outlines.from_mlxlm( ) result = model("Create a character.", output_type=Character) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` #### Multiple Choice @@ -154,7 +154,7 @@ model = outlines.from_mlxlm( ) result = model("What is the capital of France?", output_type) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` #### Regex @@ -171,7 +171,7 @@ model = outlines.from_mlxlm( ) result = model("Generate a fake social security number.", output_type) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` #### Context-Free Grammar @@ -208,7 +208,7 @@ model = outlines.from_mlxlm( ) result = model("Write an addition.", output_type, max_tokens=20) -print(result) # '23 + 48' +print(result.content) # '23 + 48' ``` ## Inference Arguments diff --git a/docs/features/models/ollama.md b/docs/features/models/ollama.md index 202927901..573a52e4f 100644 --- a/docs/features/models/ollama.md +++ b/docs/features/models/ollama.md @@ -62,7 +62,7 @@ model = outlines.from_ollama(ollama.Client(), "qwen2.5vl:3b") # Call it to generate text response = model("What's the capital of Latvia?") -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Vision @@ -96,7 +96,7 @@ prompt = [ # Generate text response = model(prompt) -print(response) # The image shows a black puppy with a curious and attentive expression. +print(response.content) # The image shows a black puppy with a curious and attentive expression. ``` #### Chat @@ -135,7 +135,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Streaming @@ -151,7 +151,7 @@ model = outlines.from_ollama(ollama.Client(), "qwen2.5vl:3b") # Stream text for chunk in model.stream("Write a short story about a cat"): - print(chunk) # 'In...' + print(chunk.content) # 'In...' ``` ## Asynchronous Calls @@ -171,7 +171,7 @@ async def generate_text(): async_model = outlines.from_ollama(async_client, "qwen2.5vl:3b") result = await async_model("Write a haiku about Python.") - print(result) + print(result.content) asyncio.run(generate_text()) ``` @@ -190,7 +190,7 @@ async def stream_text(): async_model = outlines.from_ollama(async_client, "qwen2.5vl:3b") async for chunk in async_model.stream("Tell me a story about a robot."): - print(chunk, end="") + print(chunk.content, end="") asyncio.run(stream_text()) ``` @@ -219,7 +219,7 @@ async def generate_multiple(): results = await asyncio.gather(*tasks) for prompt, result in zip(prompts, results): - print(f"{prompt}\n{result}\n") + print(f"{prompt}\n{result.content}\n") asyncio.run(generate_multiple()) ``` @@ -246,8 +246,8 @@ model = outlines.from_ollama(ollama.Client(), "tinyllama") # Call it with the output type to generate structured text result = model("Create a character", Character) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ## Inference arguments diff --git a/docs/features/models/openai.md b/docs/features/models/openai.md index 07a7fa08f..9768f5349 100644 --- a/docs/features/models/openai.md +++ b/docs/features/models/openai.md @@ -58,7 +58,7 @@ model = outlines.from_openai( # Call it to generate text response = model("What's the capital of Latvia?", max_tokens=20) -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Vision @@ -94,7 +94,7 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Chat @@ -133,7 +133,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Streaming @@ -154,7 +154,7 @@ model = outlines.from_openai( # Stream the response for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk) # 'Once...' + print(chunk.content) # 'Once...' ``` ## Structured Generation @@ -179,8 +179,8 @@ model = outlines.from_openai(openai.OpenAI(), "gpt-4o") # Call it with the output type to generate structured text result = model("Create a character, use the json format.", Character, top_p=0.1) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` #### JSON Syntax @@ -196,7 +196,7 @@ model = outlines.from_openai(openai.OpenAI(), "gpt-4o") # Call it with the output type to generate structured text result = model("Create a character, use the json format.", dict, temperature=0.5) -print(result) # '{"first_name": "Henri", "last_name": "Smith", "height": "170"}' +print(result.content) # '{"first_name": "Henri", "last_name": "Smith", "height": "170"}' ``` ## Asynchronous Calls @@ -226,20 +226,52 @@ model = outlines.from_openai( async def text_generation(): # Regular generation response = await model("What's the capital of Latvia?", max_tokens=20) - print(response) # 'Riga' + print(response.content) # 'Riga' # Streaming async for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk, end="") # 'Once...' + print(chunk.content, end="") # 'Once...' # Structured generation result = await model("Create a character, use the json format.", Character, top_p=0.1) - print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' - print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] + print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' + print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] asyncio.run(text_generation()) ``` +#### Tools + +Anthropic supports tool calling. To use it, provide a list of `tools` to the model. + +For instance: + +```python +from openai import OpenAI +from outlines import from_openai +from outlines.inputs import Chat +from typing import Optional + +# Our tool +def get_weather(city: str, hour: Optional[int] = None): + """Give the weather for a given city, and optionally for a specific hour of the day""" + return "20 degrees" + +# Create the model +model = from_openai( + OpenAI(), + "gpt-4o" +) + +# Call the model with the tool defined above +chat = Chat([ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +]) +response = model(chat, tools=[get_weather], max_tokens=100) +print(response.tool_calls) # [ToolCallOutput(name='get_weather', id='toolu_01WDUo65vCXkrmjD3Yehbc5v', args={'city': 'Tokyo'})] +``` + ## Inference arguments When calling the model, you can provide keyword arguments that will be passed down to the `chat.completions.create` method of the OpenAI client. Some of the most common arguments include `max_tokens`, `temperature`, `stop` and `top_p`. diff --git a/docs/features/models/sglang.md b/docs/features/models/sglang.md index 428b6f69c..376e67d74 100644 --- a/docs/features/models/sglang.md +++ b/docs/features/models/sglang.md @@ -63,7 +63,7 @@ model = outlines.from_openai(openai.OpenAI(base_url="http://localhost:11434")) # Call it to generate text response = model("What's the capital of Latvia?", max_tokens=20) -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Vision @@ -96,7 +96,7 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Chat @@ -132,7 +132,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Streaming @@ -150,7 +150,7 @@ model = outlines.from_openai(openai.OpenAI(base_url="http://localhost:11434")) # Stream the response for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk) # 'Once...' + print(chunk.content) # 'Once...' ``` ## Structured Generation @@ -169,7 +169,7 @@ openai_client = openai.OpenAI(base_url="http://localhost:11434") model = outlines.from_sglang(openai_client) result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ``` ### JSON Schema @@ -188,8 +188,8 @@ openai_client = openai.OpenAI(base_url="http://localhost:11434") model = outlines.from_sglang(openai_client) result = model("Create a character.", Character, frequency_penalty=1.5) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ### Multiple Choice @@ -205,7 +205,7 @@ openai_client = openai.OpenAI(base_url="http://localhost:11434") model = outlines.from_sglang(openai_client) result = model("What is the capital of France?", output_type, temperature=0) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` ### Regex @@ -221,7 +221,7 @@ openai_client = openai.OpenAI(base_url="http://localhost:11434") model = outlines.from_sglang(openai_client) result = model("Generate a fake social security number.", output_type, top_p=0.1) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Context-Free Grammar @@ -243,7 +243,7 @@ openai_client = openai.OpenAI(base_url="http://localhost:11434") model = outlines.from_sglang(openai_client) result = model("Is the weather good today?", output_type) -print(result) # 'yes' +print(result.content) # 'yes' ``` ### Async Structured Generation @@ -267,7 +267,7 @@ async def generate_user(): async_model = outlines.from_sglang(async_client) result = await async_model("Generate a random user profile.", output_type=User) - user = User.model_validate_json(result) + user = User.model_validate_json(result.content) print(f"Name: {user.name}, Email: {user.email}, Age: {user.age}") asyncio.run(generate_user()) diff --git a/docs/features/models/tgi.md b/docs/features/models/tgi.md index f71e8dad6..78e00df6a 100644 --- a/docs/features/models/tgi.md +++ b/docs/features/models/tgi.md @@ -61,7 +61,7 @@ model = outlines.from_tgi(client) # Call it to generate text result = model("Write a short story about a cat.", stop_sequences=["."]) -print(result) # 'In a quiet village where the cobblestones hummed softly beneath the morning mist...' +print(result.content) # 'In a quiet village where the cobblestones hummed softly beneath the morning mist...' ``` The `TGI` model supports streaming. For instance: @@ -76,7 +76,7 @@ model = outlines.from_tgi(client) # Stream text for chunk in model.stream("Write a short story about a cat.", stop_sequences=["."]): - print(chunk) # 'In ...' + print(chunk.content) # 'In ...' ``` ## Asynchronous Calls @@ -96,7 +96,7 @@ async def generate_text(): async_model = outlines.from_tgi(async_client) result = await async_model("Write a haiku about Python.", max_new_tokens=50) - print(result) + print(result.content) asyncio.run(generate_text()) ``` @@ -115,7 +115,7 @@ async def stream_text(): async_model = outlines.from_tgi(async_client) async for chunk in async_model.stream("Tell me a story about a robot.", max_new_tokens=100): - print(chunk, end="") + print(chunk.content, end="") asyncio.run(stream_text()) ``` @@ -144,7 +144,7 @@ async def generate_multiple(): results = await asyncio.gather(*tasks) for prompt, result in zip(prompts, results): - print(f"{prompt}\n{result}\n") + print(f"{prompt}\n{result.content}\n") asyncio.run(generate_multiple()) ``` @@ -165,7 +165,7 @@ tgi_client = huggingface_hub.InferenceClient("http://localhost:8080") model = outlines.from_tgi(tgi_client) result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ```### JSON Schema ```python @@ -183,8 +183,8 @@ tgi_client = huggingface_hub.InferenceClient("http://localhost:8080") model = outlines.from_tgi(tgi_client) result = model("Create a character.", output_type=Character, frequency_penalty=1.5) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ```### Multiple Choice ```python @@ -198,7 +198,7 @@ tgi_client = huggingface_hub.InferenceClient("http://localhost:8080") model = outlines.from_tgi(tgi_client) result = model("What is the capital of France?", output_type, temperature=0) -print(result) # 'Paris' +print(result.content) # 'Paris' ```### Regex ```python @@ -212,7 +212,7 @@ tgi_client = huggingface_hub.InferenceClient("http://localhost:8080") model = outlines.from_tgi(tgi_client) result = model("Generate a fake social security number.", output_type, top_p=0.1) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Async Structured Generation @@ -235,7 +235,7 @@ async def generate_user(): async_model = outlines.from_tgi(async_client) result = await async_model("Generate a random user profile.", output_type=User) - user = User.model_validate_json(result) + user = User.model_validate_json(result.content) print(f"Name: {user.name}, Email: {user.email}, Age: {user.age}") asyncio.run(generate_user()) diff --git a/docs/features/models/transformers.md b/docs/features/models/transformers.md index 77c5d4241..081c4e4c7 100644 --- a/docs/features/models/transformers.md +++ b/docs/features/models/transformers.md @@ -51,7 +51,7 @@ model = outlines.from_transformers( # Call it to generate text result = model("What's the capital of Latvia?", max_new_tokens=20) -print(result) # 'Riga' +print(result.content) # 'Riga' ``` #### Chat @@ -79,7 +79,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_new_tokens=50) -print(response) # 'This is a picture of a black dog.' +print(response.content) # 'This is a picture of a black dog.' ``` #### Batching @@ -107,7 +107,7 @@ prompts = [ # Call it to generate text result = model.batch(prompts, max_new_tokens=20) -print(result) # ['Vilnius', 'Riga', 'Tallinn'] +print(result.content) # ['Vilnius', 'Riga', 'Tallinn'] ``` ## Structured Generation @@ -128,7 +128,7 @@ model = outlines.from_transformers( ) result = model("How many countries are there in the world?", output_type, max_new_tokens=5) -print(result) # '200' +print(result.content) # '200' ``` ### JSON Schema @@ -150,8 +150,8 @@ model = outlines.from_transformers( ) result = model("Create a character.", output_type=Character, max_new_tokens=200, repetition_penalty=0.5) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ### Multiple Choice @@ -169,7 +169,7 @@ model = outlines.from_transformers( ) result = model("What is the capital of France?", output_type, max_new_tokens=10, temperature=0) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` ### Regex @@ -187,7 +187,7 @@ model = outlines.from_transformers( ) result = model("Generate a fake social security number.", output_type, max_new_tokens=20, top_p=0.5) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Context-Free Grammar @@ -225,7 +225,7 @@ model = outlines.from_transformers( ) result = model("Write an addition.", output_type, max_new_tokens=100) -print(result) # '23 + 48' +print(result.content) # '23 + 48' ``` ## Inference Arguments diff --git a/docs/features/models/transformers_multimodal.md b/docs/features/models/transformers_multimodal.md index 24d119de3..f73e7fdd0 100644 --- a/docs/features/models/transformers_multimodal.md +++ b/docs/features/models/transformers_multimodal.md @@ -73,8 +73,8 @@ result = model( Animal, max_new_tokens=100 ) -print(result) # '{"specie": "cat", "color": "white", "weight": 4}' -print(Animal.model_validate_json(result)) # specie=cat, color=white, weight=4 +print(result.content) # '{"specie": "cat", "color": "white", "weight": 4}' +print(Animal.model_validate_json(result.content)) # specie=cat, color=white, weight=4 ``` The `TransformersMultiModal` model supports batch generation. To use it, invoke the `batch` method with a list of lists. You will receive as a result a list of completions. @@ -117,7 +117,7 @@ result = model.batch( ["Describe the image.", Image(get_image_from_url(IMAGE_URL_2))], ] ) -print(result) # ['The image shows a cat', 'The image shows an astronaut'] +print(result.content) # ['The image shows a cat', 'The image shows an astronaut'] ``` !!! Warning diff --git a/docs/features/models/vllm.md b/docs/features/models/vllm.md index 671354b3d..2711e2d6e 100644 --- a/docs/features/models/vllm.md +++ b/docs/features/models/vllm.md @@ -62,7 +62,7 @@ model = outlines.from_vllm(openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_ # Call it to generate text response = model("What's the capital of Latvia?", max_tokens=20) -print(response) # 'The capital of Latvia is Riga.' +print(response.content) # 'The capital of Latvia is Riga.' ``` #### Vision @@ -98,7 +98,7 @@ prompt = [ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'The image shows a black puppy lying on a wooden surface...' +print(response.content) # 'The image shows a black puppy lying on a wooden surface...' ``` #### Chat @@ -137,7 +137,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'The image shows a black puppy lying on a wooden surface...' +print(response.content) # 'The image shows a black puppy lying on a wooden surface...' ``` #### Streaming @@ -158,8 +158,7 @@ model = outlines.from_vllm( # Stream the response for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk, end="") # 'Once upon a time...' -print() + print(chunk.content, end="") # 'Once upon a time...' ``` ## Asynchronous Calls @@ -178,7 +177,7 @@ async def generate_text(): async_model = outlines.from_vllm(async_client, "microsoft/Phi-3-mini-4k-instruct") result = await async_model("Write a haiku about Python.", max_tokens=50) - print(result) + print(result.content) asyncio.run(generate_text()) ``` @@ -197,7 +196,7 @@ async def stream_text(): async_model = outlines.from_vllm(async_client, "microsoft/Phi-3-mini-4k-instruct") async for chunk in async_model.stream("Tell me a story about a robot.", max_tokens=100): - print(chunk, end="") + print(chunk.content, end="") asyncio.run(stream_text()) ``` @@ -225,7 +224,7 @@ async def generate_multiple(): results = await asyncio.gather(*tasks) for prompt, result in zip(prompts, results): - print(f"{prompt}\n{result}\n") + print(f"{prompt}\n{result.content}\n") asyncio.run(generate_multiple()) ``` @@ -246,7 +245,7 @@ openai_client = openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="token- model = outlines.from_vllm(openai_client, "microsoft/Phi-3-mini-4k-instruct") result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ``` ### JSON Schema @@ -266,8 +265,8 @@ openai_client = openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="token- model = outlines.from_vllm(openai_client, "microsoft/Phi-3-mini-4k-instruct") result = model("Create a character.", output_type=Character, frequency_penalty=1.5) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ### Multiple Choice @@ -283,7 +282,7 @@ openai_client = openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="token- model = outlines.from_vllm(openai_client, "microsoft/Phi-3-mini-4k-instruct") result = model("What is the capital of France?", output_type, temperature=0) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` ### Regex @@ -299,7 +298,7 @@ openai_client = openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="token- model = outlines.from_vllm(openai_client, "microsoft/Phi-3-mini-4k-instruct") result = model("Generate a fake social security number.", output_type, top_p=0.1) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Context-Free Grammar @@ -335,7 +334,7 @@ openai_client = openai.OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="token- model = outlines.from_vllm(openai_client, "microsoft/Phi-3-mini-4k-instruct") result = model("Write an addition.", output_type, extra_body={"guided_decoding_backend": "outlines"}) -print(result) # '23 + 48' +print(result.content) # '23 + 48' ``` ### Async Structured Generation @@ -358,7 +357,7 @@ async def generate_user(): async_model = outlines.from_vllm(async_client, "microsoft/Phi-3-mini-4k-instruct") result = await async_model("Generate a random user profile.", output_type=User) - user = User.model_validate_json(result) + user = User.model_validate_json(result.content) print(f"Name: {user.name}, Email: {user.email}, Age: {user.age}") asyncio.run(generate_user()) diff --git a/docs/features/models/vllm_offline.md b/docs/features/models/vllm_offline.md index 91b454d68..b7e848f06 100644 --- a/docs/features/models/vllm_offline.md +++ b/docs/features/models/vllm_offline.md @@ -53,7 +53,7 @@ model = outlines.from_vllm_offline( # Call it to generate text response = model("What's the capital of Latvia?", max_tokens=20) -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Chat @@ -80,7 +80,7 @@ prompt = Chat([ # Call the model to generate a response response = model(prompt, max_tokens=50) -print(response) # 'Riga' +print(response.content) # 'Riga' ``` #### Streaming @@ -100,7 +100,7 @@ model = outlines.from_vllm_offline( # Stream the response for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50): - print(chunk) # 'Once...' + print(chunk.content) # 'Once...' ``` #### Batching @@ -127,7 +127,7 @@ prompts = [ # Call it to generate text result = model.batch(prompts, max_new_tokens=20) -print(result) # ['Vilnius', 'Riga', 'Tallinn'] +print(result.content) # ['Vilnius', 'Riga', 'Tallinn'] ``` ## Structured Generation @@ -147,7 +147,7 @@ model = outlines.from_vllm_offline( ) result = model("How many countries are there in the world?", output_type) -print(result) # '200' +print(result.content) # '200' ``` ### JSON Schema @@ -168,8 +168,8 @@ model = outlines.from_vllm_offline( ) result = model("Create a character.", output_type=Character, sampling_params=SamplingParams(frequency_penalty=1.5, max_tokens=200)) -print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' -print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] +print(result.content) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}' +print(Character.model_validate_json(result.content)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy'] ``` ### Multiple Choice @@ -186,7 +186,7 @@ model = outlines.from_vllm_offline( ) result = model("What is the capital of France?", output_type, sampling_params=SamplingParams(temperature=0)) -print(result) # 'Paris' +print(result.content) # 'Paris' ``` ### Regex @@ -203,7 +203,7 @@ model = outlines.from_vllm_offline( ) result = model("Generate a fake social security number.", output_type, sampling_params=SamplingParams(top_p=0.1)) -print(result) # '782-32-3789' +print(result.content) # '782-32-3789' ``` ### Context-Free Grammar @@ -240,7 +240,7 @@ model = outlines.from_vllm_offline( ) result = model("Write an addition.", output_type) -print(result) # '23 + 48' +print(result.content) # '23 + 48' ``` ## Inference Arguments diff --git a/docs/guide/fastapi_vllm_deployment.md b/docs/guide/fastapi_vllm_deployment.md index a5660aeb6..a43a45373 100644 --- a/docs/guide/fastapi_vllm_deployment.md +++ b/docs/guide/fastapi_vllm_deployment.md @@ -146,7 +146,7 @@ async def analyze_ticket(request: TicketRequest): try: # Generate and parse a structured response result = await async_model(prompt, TicketAnalysis, max_tokens=5000) - analysis = TicketAnalysis.model_validate_json(result) + analysis = TicketAnalysis.model_validate_json(result.content) return analysis @@ -174,7 +174,7 @@ async def generate_response( try: # Generate and parse a structured response result = await async_model(prompt, SupportResponse, max_tokens=5000) - response = SupportResponse.model_validate_json(result) + response = SupportResponse.model_validate_json(result.content) return response diff --git a/docs/guide/getting_started.md b/docs/guide/getting_started.md index 617e27e5c..8ecb74267 100644 --- a/docs/guide/getting_started.md +++ b/docs/guide/getting_started.md @@ -182,7 +182,7 @@ model = # Call the model to generate text result = model("Write a short story about a cat.") -print(result) # 'In a quiet village where the cobblestones hummed softly beneath the morning mist...' +print(result.content) # 'In a quiet village where the cobblestones hummed softly beneath the morning mist...' ``` Most models also support streaming through the use of a `streaming` method. You can directly use with a prompt just like regular text generation. For instance: @@ -192,7 +192,7 @@ model = # Stream text for chunk in model.streaming("Write a short story about a cat.") - print(chunk) # 'In ...' + print(chunk.content) # 'In ...' ``` ## Structured Generation @@ -218,7 +218,7 @@ In the meantime, you can find below examples of using each of the five output ty # Generate an integer result = model("How many countries are there in the world?", int) - print(result) # '200' + print(result.content) # '200' ``` === "Multiple Choice" @@ -235,7 +235,7 @@ In the meantime, you can find below examples of using each of the five output ty # Generate text corresponding to either of the choices defined above result = model("What do you want to eat, a pizza or a burger?", PizzaOrBurger) - print(result) # 'pizza' + print(result.content) # 'pizza' ``` === "JSON Schemas" @@ -255,8 +255,8 @@ In the meantime, you can find below examples of using each of the five output ty # Generate a character result = model("Create a character", Character) - print(result) # '{"name": "Aurora", "birth_date": "1990-06-15", "skills": ["Stealth", "Diplomacy"]}' - print(Character.model_validate_json(result)) # name=Aurora birth_date=datetime.date(1990, 6, 15) skills=['Stealth', 'Diplomacy'] + print(result.content) # '{"name": "Aurora", "birth_date": "1990-06-15", "skills": ["Stealth", "Diplomacy"]}' + print(Character.model_validate_json(result.content)) # name=Aurora birth_date=datetime.date(1990, 6, 15) skills=['Stealth', 'Diplomacy'] ``` === "Regex" @@ -271,7 +271,7 @@ In the meantime, you can find below examples of using each of the five output ty # Generate the number result = model("Write a 3 digit number", output_type) - print(result) # '236' + print(result.content) # '236' ``` === "Context-free Grammars" @@ -305,7 +305,7 @@ In the meantime, you can find below examples of using each of the five output ty # Generate an arithmetic operation result = model("Write an arithmetic operation", CFG(grammar_string)) - print(result) # '2 + 3' + print(result.content) # '2 + 3' ``` It's important to note that not all output types are available for all models due to limitations in the underlying inference engines. The [Models](../features/models/index.md) section of the features documentation includes a features matrix that summarize the availability of output types. @@ -329,7 +329,7 @@ generator = Generator(model, Literal["pizza", "burger"]) # Call it as you would call a model result = generator("What do you want to eat, a pizza or a burger?") -print(result) # pizza +print(result.content) # pizza ``` You can find more information on generators in the dedicated page on [Generators](../features/core/generator.md) in the features documentation. diff --git a/docs/guide/vlm.md b/docs/guide/vlm.md index 979e65924..ced73231d 100644 --- a/docs/guide/vlm.md +++ b/docs/guide/vlm.md @@ -157,7 +157,7 @@ result = image_data_generator({ "text": pixtral_instruction, "images": image }) -print(result) +print(result.content) ``` This code loads an image from a URL, passes it to our vision multi-modal model along with the instruction prompt, and generates a structured output based on the defined schema. We end up with an output like this, ready to be used for the next stage in your pipeline: diff --git a/examples/babyagi.py b/examples/babyagi.py index 4a817906d..a7f358f36 100644 --- a/examples/babyagi.py +++ b/examples/babyagi.py @@ -78,7 +78,7 @@ def one_cycle(objective: str, task_list, next_task_id: int): task = task_list.popleft() prompt = perform_task_ppt(objective=objective, task=task) - result = complete(prompt) + result = complete(prompt).content prompt = create_tasks_ppt( objective=objective, @@ -86,7 +86,7 @@ def one_cycle(objective: str, task_list, next_task_id: int): result=result, previous_tasks=[first_task["task_name"]], ) - new_tasks = complete(prompt) + new_tasks = complete(prompt).content new_tasks = create_tasks_fmt(new_tasks) @@ -99,7 +99,7 @@ def one_cycle(objective: str, task_list, next_task_id: int): tasks=[task["task_name"] for task in task_list], next_task_id=next_task_id, ) - prioritized_tasks = complete(prompt) + prioritized_tasks = complete(prompt).content prioritized_tasks = prioritize_tasks_fmt(prioritized_tasks) diff --git a/examples/beam-cloud/app.py b/examples/beam-cloud/app.py index 781bc1d3d..7bc96b42d 100644 --- a/examples/beam-cloud/app.py +++ b/examples/beam-cloud/app.py @@ -42,5 +42,5 @@ def predict(context, **inputs): model = context.on_start_value # Inference generator = outlines.Generator(model, Literal["Positive", "Negative"]) - answer = generator(prompt) + answer = generator(prompt).content return {"answer": answer} diff --git a/examples/bentoml/service.py b/examples/bentoml/service.py index 370bac901..a6be04e4b 100644 --- a/examples/bentoml/service.py +++ b/examples/bentoml/service.py @@ -77,6 +77,6 @@ async def generate( import outlines generator = outlines.Generator(self.model, outlines.json_schema(json_schema)) - character = generator(prompt) + character = generator(prompt).content return character diff --git a/examples/cerebrium/main.py b/examples/cerebrium/main.py index 4e1987197..fe6f60859 100644 --- a/examples/cerebrium/main.py +++ b/examples/cerebrium/main.py @@ -43,7 +43,7 @@ def generate( character = model( f"[INST]Give me a character description. Describe {prompt}.[/INST]", outlines.json_schema(schema), - ) + ).content print(character) return character diff --git a/examples/dating_profile.py b/examples/dating_profile.py index ff488698e..1869531db 100644 --- a/examples/dating_profile.py +++ b/examples/dating_profile.py @@ -96,7 +96,7 @@ class Example: dating_profile_prompt = Template.from_file("prompts/dating_profile.txt") prompt = dating_profile_prompt(description=new_description, examples=samples) -profile = model(prompt, outlines.json_schema(DatingProfile), max_tokens=500) # type: ignore +profile = model(prompt, outlines.json_schema(DatingProfile), max_tokens=500).content # type: ignore print(profile) # Sample generated profiles diff --git a/examples/llamacpp_example.py b/examples/llamacpp_example.py index f6afc40ed..93d4f7f22 100644 --- a/examples/llamacpp_example.py +++ b/examples/llamacpp_example.py @@ -42,4 +42,4 @@ class Character(BaseModel): prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:" sequence = generator(prompt, seed=seed, max_tokens=512) - print(sequence) + print(sequence.content) diff --git a/examples/math_generate_code.py b/examples/math_generate_code.py index a5a53f084..b7378b40a 100644 --- a/examples/math_generate_code.py +++ b/examples/math_generate_code.py @@ -39,6 +39,6 @@ def execute_code(code): prompt = answer_with_code_prompt(question=question, examples=examples) model = outlines.from_openai(openai.OpenAI(), "gpt-4o-mini") -answer = model(prompt) +answer = model(prompt).content result = execute_code(answer) print(f"It takes Carla {result:.0f} minutes to download the file.") diff --git a/examples/meta_prompting.py b/examples/meta_prompting.py index ed4b7455a..53aaa8b46 100644 --- a/examples/meta_prompting.py +++ b/examples/meta_prompting.py @@ -31,12 +31,12 @@ def split_into_steps(question, model_name: str): model = outlines.from_openai(client, model_name) prompt = solve(question=question) - answer = model(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500).content prompt += ( answer + "\n what is the only option that displays the same type of relationship as : :?" ) - answer = model(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500).content completed = prompt + answer return completed @@ -55,9 +55,9 @@ def fill_in_the_blanks(question, model_name: str): model = outlines.from_openai(client, model_name) prompt = determine_goal(question=question) - answer = model(prompt, stop=["."]) + answer = model(prompt, stop=["."]).content prompt = solve(memory=prompt + answer) - answer = model(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500).content completed = prompt + answer return completed @@ -94,9 +94,9 @@ def ask_an_expert(question, model_name: str): model = outlines.from_openai(client, model_name) prompt = find_expert(question=question) - expert = model(prompt, stop=['"']) + expert = model(prompt, stop=['"']).content prompt = get_answer(question=question, expert=expert, memory=prompt+expert) - answer = model(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500).content completed = prompt + answer return completed @@ -121,9 +121,9 @@ def ask_an_expert_simple(question, model_name: str): model = outlines.from_openai(client, model_name) prompt = find_expert(question=question) - expert = model(prompt, stop=["\n", "."]) + expert = model(prompt, stop=["\n", "."]).content prompt = get_answer(expert=expert, memory=prompt+expert) - answer = model(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500).content completed = prompt + answer return completed diff --git a/examples/modal_example.py b/examples/modal_example.py index fe0e6bc67..e7f2b5183 100644 --- a/examples/modal_example.py +++ b/examples/modal_example.py @@ -76,7 +76,7 @@ def generate( character = model( f"[INST]Give me a character description. Describe {prompt}.[/INST]", outlines.json_schema(schema), - ) + ).content print(character) diff --git a/examples/pick_odd_one_out.py b/examples/pick_odd_one_out.py index 36ec9dfe7..a71520655 100644 --- a/examples/pick_odd_one_out.py +++ b/examples/pick_odd_one_out.py @@ -39,8 +39,8 @@ prompt = build_ooo_prompt(options=options) reasoning = gen_text(prompt, stop=["Pick the odd word", "So the odd one"]) -prompt += reasoning +prompt += reasoning.content raw_result = gen_choice(prompt) -result = json.loads(raw_result)["result"] +result = json.loads(raw_result.content)["result"] prompt += result print(result) diff --git a/examples/react.py b/examples/react.py index b7c807726..6abafa864 100644 --- a/examples/react.py +++ b/examples/react.py @@ -80,20 +80,20 @@ def search_wikipedia(query: str): for i in range(1, 10): mode_output = mode_generator(prompt, max_tokens=128) - mode = json.loads(mode_output)["result"] # Extract the result from the JSON output + mode = json.loads(mode_output.content)["result"] # Extract the result from the JSON output prompt = add_mode(i=i, mode=mode, result="", prompt=prompt) if mode == "Tho": thought = text_generator(prompt, stop="\n", max_tokens=128) - prompt += f"{thought}" + prompt += f"{thought.content}" elif mode == "Act": action_output = action_generator(prompt, max_tokens=128) - action = json.loads(action_output)["result"] # Extract the result from the JSON output + action = json.loads(action_output.content)["result"] # Extract the result from the JSON output prompt += f"{action} '" subject = text_generator(prompt, stop=["'"], max_tokens=128) # Apple Computers headquartered - subject = " ".join(subject.split()[:2]) + subject = " ".join(subject.content.split()[:2]) prompt += f"{subject}'" if action == "Search": diff --git a/examples/self_consistency.py b/examples/self_consistency.py index 061ff594c..65cac6ce5 100644 --- a/examples/self_consistency.py +++ b/examples/self_consistency.py @@ -54,7 +54,7 @@ digits = [] for answer in answers: try: - match = re.findall(r"\d+", answer)[-1] + match = re.findall(r"\d+", answer.content)[-1] if match is not None: digit = int(match) digits.append(digit) diff --git a/examples/simulation_based_inference.ipynb b/examples/simulation_based_inference.ipynb index e1f888324..d0f169bc3 100644 --- a/examples/simulation_based_inference.ipynb +++ b/examples/simulation_based_inference.ipynb @@ -108,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "9fbebaa9-f05e-4c6b-8875-73a08273bbb5", "metadata": {}, "outputs": [], @@ -137,7 +137,7 @@ " samples = []\n", " for answer_raw in answers_raw:\n", " try:\n", - " answer = re.findall(r\"\\d+\", answer_raw)[-1]\n", + " answer = re.findall(r\"\\d+\", answer_raw.content)[-1]\n", " if answer == problem[\"answer\"]:\n", " samples += example_ids\n", " else:\n", diff --git a/outlines/__init__.py b/outlines/__init__.py index 007c376da..62c531176 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -4,7 +4,9 @@ from outlines import grammars as grammars from outlines import inputs as inputs from outlines import models as models +from outlines import outputs as outputs from outlines import processors as processors +from outlines import tools as tools from outlines import types as types from outlines.applications import Application as Application from outlines.caching import clear_cache as clear_cache @@ -13,10 +15,17 @@ from outlines.generator import Generator as Generator from outlines.inputs import Audio as Audio from outlines.inputs import Image as Image +from outlines.inputs import ToolCall as ToolCall +from outlines.inputs import SystemMessage as SystemMessage +from outlines.inputs import UserMessage as UserMessage +from outlines.inputs import AssistantMessage as AssistantMessage +from outlines.inputs import ToolMessage as ToolMessage +from outlines.inputs import Chat as Chat from outlines.inputs import Video as Video from outlines.models import * # noqa: F403 from outlines.templates import Template as Template from outlines.templates import Vision as Vision +from outlines.tools import ToolDef as ToolDef from outlines.types import cfg as cfg from outlines.types import json_schema as json_schema from outlines.types import regex as regex diff --git a/outlines/applications.py b/outlines/applications.py index 4b4b34187..e0d9fee4c 100644 --- a/outlines/applications.py +++ b/outlines/applications.py @@ -48,7 +48,7 @@ class OutputModel(BaseModel): application = Application(template, JsonType(OutputModel)) result = application(model, {"num": 3}, max_new_tokens=20) - print(result) # Expected output: { "result" : 6 } + print(result.content) # Expected output: '{ "result" : 6 }' ``` """ diff --git a/outlines/generator.py b/outlines/generator.py index f2e669d8f..a64319060 100644 --- a/outlines/generator.py +++ b/outlines/generator.py @@ -21,6 +21,8 @@ get_regex_logits_processor, ) from outlines.backends.base import LogitsProcessorType +from outlines.outputs import Output, StreamingOutput +from outlines.tools import ToolDef, ToolsInput, get_formatted_tools from outlines.types import CFG, JsonSchema from outlines.types.dsl import python_types_to_terms, to_regex @@ -34,8 +36,15 @@ class BlackBoxGenerator: """ output_type: Optional[Any] + tools: Optional[List[ToolDef]] - def __init__(self, model: BlackBoxModel, output_type: Optional[Any]): + def __init__( + self, + model: BlackBoxModel, + output_type: Optional[Any], + *, + tools: Optional[ToolsInput] = None, + ): """ Parameters ---------- @@ -43,12 +52,18 @@ def __init__(self, model: BlackBoxModel, output_type: Optional[Any]): An instance of an Outlines model. output_type The output type that will be used to constrain the generation. + tools + A list of tools to use for the generator. Can contain a list of + ToolDef, Callable, or BaseModel instances. """ self.model = model self.output_type = output_type + self.tools = get_formatted_tools(tools) - def __call__(self, prompt: Any, **inference_kwargs) -> Any: + def __call__( + self, prompt: Any, **inference_kwargs + ) -> Output | List[Output]: """Generate a response from the model. Parameters @@ -60,15 +75,17 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any: Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ return self.model.generate( - prompt, self.output_type, **inference_kwargs + prompt, self.output_type, tools=self.tools, **inference_kwargs ) - def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: + def batch( + self, prompts: List[Any], **inference_kwargs + ) -> List[Output] | List[List[Output]]: """Generate a batch of responses from the model. Parameters @@ -80,15 +97,17 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ return self.model.generate_batch( - prompts, self.output_type, **inference_kwargs + prompts, self.output_type, tools=self.tools, **inference_kwargs ) - def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]: + def stream( + self, prompt: Any, **inference_kwargs + ) -> Iterator[StreamingOutput]: """Generate a stream of responses from the model. Parameters @@ -100,12 +119,12 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]: Returns ------- - Any - The response generated by the model. + Iterator[StreamingOutput] + A stream of StreamingOutput generated by the model. """ return self.model.generate_stream( - prompt, self.output_type, **inference_kwargs + prompt, self.output_type, tools=self.tools, **inference_kwargs ) @@ -118,8 +137,15 @@ class AsyncBlackBoxGenerator: """ output_type: Optional[Any] + tools: Optional[List[ToolDef]] - def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]): + def __init__( + self, + model: AsyncBlackBoxModel, + output_type: Optional[Any], + *, + tools: Optional[ToolsInput] = None, + ): """ Parameters ---------- @@ -127,12 +153,18 @@ def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]): An instance of an Outlines model. output_type The output type that will be used to constrain the generation. + tools + A list of tools to use for the generator. Can contain a list of + ToolDef, Callable, or BaseModel instances. """ self.model = model self.output_type = output_type + self.tools = get_formatted_tools(tools) - async def __call__(self, prompt: Any, **inference_kwargs) -> Any: + async def __call__( + self, prompt: Any, **inference_kwargs + ) -> Output | List[Output]: """Generate a response from the model. Parameters @@ -144,15 +176,17 @@ async def __call__(self, prompt: Any, **inference_kwargs) -> Any: Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ return await self.model.generate( - prompt, self.output_type, **inference_kwargs + prompt, self.output_type, tools=self.tools, **inference_kwargs ) - async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: + async def batch( + self, prompts: List[Any], **inference_kwargs + ) -> List[Output] | List[List[Output]]: """Generate a batch of responses from the model. Parameters @@ -164,15 +198,17 @@ async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ return await self.model.generate_batch( - prompts, self.output_type, **inference_kwargs + prompts, self.output_type, tools=self.tools, **inference_kwargs ) - async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]: + async def stream( + self, prompt: Any, **inference_kwargs + ) -> AsyncIterator[StreamingOutput]: """Generate a stream of responses from the model. Parameters @@ -184,12 +220,13 @@ async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]: Returns ------- - Any - The response generated by the model. + AsyncIterator[StreamingOutput] + A coroutine that will produce an async iterator of StreamingOutput + produced by the model. """ async for chunk in self.model.generate_stream( # pragma: no cover - prompt, self.output_type, **inference_kwargs + prompt, self.output_type, tools=self.tools, **inference_kwargs ): yield chunk @@ -212,12 +249,15 @@ class SteerableGenerator: """ logits_processor: Optional[LogitsProcessorType] + tools: Optional[List[ToolDef]] def __init__( self, model: SteerableModel, output_type: Optional[Any], backend_name: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, ): """ Parameters @@ -228,13 +268,21 @@ def __init__( The output type expressed as a Python type backend_name The name of the backend to use to create the logits processor. + tools + A list of tools to use for the generator. Can contain a list of + ToolDef, Callable, or BaseModel instances. """ self.model = model + # tools are not implemented yet for steerable models + # error is raised within each model's type_adapter + self.tools = get_formatted_tools(tools) + if output_type is None: self.logits_processor = None else: term = python_types_to_terms(output_type) + if isinstance(term, CFG): cfg_string = term.definition self.logits_processor = get_cfg_logits_processor( @@ -258,7 +306,11 @@ def __init__( @classmethod def from_processor( - cls, model: SteerableModel, processor: LogitsProcessorType + cls, + model: SteerableModel, + processor: LogitsProcessorType, + *, + tools: Optional[ToolsInput] = None, ): """Create a generator from a logits processor. @@ -270,13 +322,14 @@ def from_processor( An instance of a logits processor. """ - instance = cls.__new__(cls) - instance.model = model + instance = cls(model, None, tools=tools) instance.logits_processor = processor return instance - def __call__(self, prompt: Any, **inference_kwargs) -> Any: + def __call__( + self, prompt: Any, **inference_kwargs + ) -> Output | List[Output]: """Generate a response from the model. Parameters @@ -288,17 +341,20 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any: Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ if self.logits_processor is not None: self.logits_processor.reset() + return self.model.generate( - prompt, self.logits_processor, **inference_kwargs + prompt, self.logits_processor, tools=self.tools, **inference_kwargs ) - def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: + def batch( + self, prompts: List[Any], **inference_kwargs + ) -> List[Output] | List[List[Output]]: """Generate a batch of responses from the model. Parameters @@ -310,17 +366,20 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]: Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ if self.logits_processor is not None: self.logits_processor.reset() + return self.model.generate_batch( - prompts, self.logits_processor, **inference_kwargs + prompts, self.logits_processor, tools=self.tools, **inference_kwargs ) - def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]: + def stream( + self, prompt: Any, **inference_kwargs + ) -> Iterator[StreamingOutput]: """Generate a stream of responses from the model. Parameters @@ -332,14 +391,15 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]: Returns ------- - Any - The response generated by the model. + Iterator[StreamingOutput] + A stream of StreamingOutput generated by the model. """ if self.logits_processor is not None: self.logits_processor.reset() + return self.model.generate_stream( - prompt, self.logits_processor, **inference_kwargs + prompt, self.logits_processor, tools=self.tools, **inference_kwargs ) @@ -348,6 +408,7 @@ def Generator( output_type: Optional[Any] = None, backend: Optional[str] = None, *, + tools: Optional[ToolsInput] = None, processor: Optional[LogitsProcessorType] = None, ) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]: """Create a generator for the given model and output parameters. @@ -367,6 +428,9 @@ def Generator( The name of the backend to use to create the logits processor. Only used for steerable models if there is an output type and `processor` is not provided. + tools + A list of tools to use for the generator. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. processor An instance of a logits processor. @@ -380,6 +444,7 @@ def Generator( param is not None for param in [output_type, processor] ) + if provided_output_params > 1: raise ValueError( "At most one of output_type or processor can be provided" @@ -387,18 +452,18 @@ def Generator( if isinstance(model, SteerableModel): # type: ignore if processor is not None: - return SteerableGenerator.from_processor(model, processor) # type: ignore + return SteerableGenerator.from_processor(model, processor, tools=tools) # type: ignore else: - return SteerableGenerator(model, output_type, backend) # type: ignore + return SteerableGenerator(model, output_type, backend, tools=tools) # type: ignore else: if processor is not None: raise NotImplementedError( "This model does not support logits processors" ) if isinstance(model, AsyncBlackBoxModel): # type: ignore - return AsyncBlackBoxGenerator(model, output_type) # type: ignore + return AsyncBlackBoxGenerator(model, output_type, tools=tools) # type: ignore elif isinstance(model, BlackBoxModel): # type: ignore - return BlackBoxGenerator(model, output_type) # type: ignore + return BlackBoxGenerator(model, output_type, tools=tools) # type: ignore else: raise ValueError( "The model argument must be an instance of " diff --git a/outlines/inputs.py b/outlines/inputs.py index 50a6a6741..61373a2f9 100644 --- a/outlines/inputs.py +++ b/outlines/inputs.py @@ -1,12 +1,20 @@ """Contain classes used to define the inputs of a model.""" import base64 +import sys from dataclasses import dataclass from io import BytesIO -from typing import Any, Dict, List, Optional +from typing import Any, List, Literal, Optional, Union from PIL import Image as PILImage +from outlines.outputs import Output + +if sys.version_info >= (3, 12): # pragma: no cover + from typing import TypedDict +else: # pragma: no cover + from typing_extensions import TypedDict + @dataclass class Image: @@ -71,46 +79,154 @@ class Audio: audio: Any -@dataclass +class ToolCall(TypedDict): + tool_name: str + tool_call_id: Optional[str] + args: dict[str, Any] + + +class SystemMessage(TypedDict): + role: Literal["system"] + content: str + + +class UserMessage(TypedDict): + role: Literal["user"] + content: str | List + + +class AssistantMessage(TypedDict): + role: Literal["assistant"] + content: Optional[str] + tool_calls: Optional[List[ToolCall]] + + +class ToolMessage(TypedDict): + role: Literal["tool"] + tool_name: Optional[str] + tool_call_id: Optional[str] + content: str | List + + +Message = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage] + + class Chat: """Contains the input for a chat model. Provide an instance of this class as the `model_input` argument to a model that supports chat. - Each message contained in the messages list must be a dict with 'role' and - 'content' keys. The role can be 'user', 'assistant', or 'system'. The content - can be a string or a list containing a str and assets (images, videos, - audios, etc.) in the case of multimodal models. + Each element contained in the messages list must be a Message or an Output + instance. Examples -------- ```python + import transformers + import outlines + from outlines.inputs import Chat, Image + + MODEL_ID = "microsoft/Phi-3-mini-4k-instruct" + + model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID), + transformers.AutoTokenizer.from_pretrained(MODEL_ID), + ) + # Initialize the chat with a system message. chat_prompt = Chat([ {"role": "system", "content": "You are a helpful assistant."}, ]) - # Add a user message with an image and call the model (not shown here). - chat_prompt.add_user_message(["Describe the image below", Image(image)]) + # Add a user message to the chat. + chat_prompt.add_user_message("What's the capital of Latvia?") - # Add as an assistant message the response from the model. - chat_prompt.add_assistant_message("The is a black cat sitting on a couch.") - ``` + # Call the model with the chat input. + response = model(chat_prompt) + print(response.content) # 'The capital of Latvia is Riga.' - Parameters - ---------- - messages - The list of messages that will be provided to the model. + # Add the output to the chat. + chat_prompt.add_output(response) + + # Add another user message to the chat and call the model again. + chat_prompt.add_user_message("How many inhabitants does it have?") + response = model(chat_prompt) + print(response.content) # '600,000' + ``` """ - messages: List[Dict[str, Any]] = None # type: ignore + def __init__(self, messages: Optional[List[Message | Output]] = None): + """ + Parameters + ---------- + messages + The list of messages and outputs that will be provided to the + model. - def __post_init__(self): - if self.messages is None: - self.messages = [] + """ + if not messages: + messages = [] + self.messages = self._format_messages(messages) - def append(self, message: Dict[str, Any]): + def _format_messages( + self, messages: List[Message | Output] + ) -> List[Message]: + """Transform a list of messages or outputs to a list of messages. + + Parameters + ---------- + messages + The list of messages or outputs to transform. + + Returns + ------- + List[Message] + The list of messages. + + """ + return [ + self._output_to_assistant_message(message) + if isinstance(message, Output) + else message + for message in messages + ] + + def _output_to_assistant_message(self, output: Output) -> AssistantMessage: + """Transform an Output instance to an AssistantMessage instance. + + Parameters + ---------- + output + The Output instance to transform. + + Returns + ------- + AssistantMessage + The AssistantMessage instance. + + """ + if output.tool_calls: + return AssistantMessage( + role="assistant", + content=output.content, + tool_calls=[ + ToolCall( + tool_name=tool_call.name, + tool_call_id=tool_call.id, + args=tool_call.args + ) + for tool_call in output.tool_calls + ], + ) + else: + return AssistantMessage( + role="assistant", + content=output.content, + tool_calls=None, + ) + + def append(self, message: Message): """Add a message to the chat. Parameters @@ -119,9 +235,9 @@ def append(self, message: Dict[str, Any]): The message to add to the chat. """ - self.messages.append(message) + self.messages.extend(self._format_messages([message])) - def extend(self, messages: List[Dict[str, Any]]): + def extend(self, messages: List[Message | Output]): """Add a list of messages to the chat. Parameters @@ -130,9 +246,9 @@ def extend(self, messages: List[Dict[str, Any]]): The list of messages to add to the chat. """ - self.messages.extend(messages) + self.messages.extend(self._format_messages(messages)) - def pop(self) -> Dict[str, Any]: + def pop(self) -> Message: """Remove the last message from the chat. Returns @@ -143,7 +259,7 @@ def pop(self) -> Dict[str, Any]: """ return self.messages.pop() - def add_system_message(self, content: str | List[Any]): + def add_system_message(self, content: Any): """Add a system message to the chat. Parameters @@ -152,9 +268,9 @@ def add_system_message(self, content: str | List[Any]): The content of the system message. """ - self.messages.append({"role": "system", "content": content}) + self.messages.append(SystemMessage(role="system", content=content)) - def add_user_message(self, content: str | List[Any]): + def add_user_message(self, content: Any): """Add a user message to the chat. Parameters @@ -163,18 +279,69 @@ def add_user_message(self, content: str | List[Any]): The content of the user message. """ - self.messages.append({"role": "user", "content": content}) + self.messages.append(UserMessage(role="user", content=content)) - def add_assistant_message(self, content: str | List[Any]): + def add_assistant_message( + self, + content: Any, + tool_calls: Optional[List[ToolCall]] = None + ): """Add an assistant message to the chat. Parameters ---------- content The content of the assistant message. + tool_calls + The tool calls of the assistant message. + + """ + self.messages.append( + AssistantMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + ) + ) + + def add_tool_message( + self, + content: str, + tool_call_id: Optional[str] = None, + tool_name: Optional[str] = None, + ): + """Add a tool message to the chat. + + Parameters + ---------- + content + The content of the tool message. + tool_call_id + The ID of the tool call. + tool_name + The name of the tool. + + """ + self.messages.append( + ToolMessage( + role="tool", + content=content, + tool_call_id=tool_call_id, + tool_name=tool_name, + ) + ) + + def add_output(self, output: Output): + """Add a generated output to the chat. The output will be converted to + an assistant message. + + Parameters + ---------- + output + The output to add to the chat. """ - self.messages.append({"role": "assistant", "content": content}) + self.messages.append(self._output_to_assistant_message(output)) def __str__(self): return "\n".join(str(message) for message in self.messages) diff --git a/outlines/models/anthropic.py b/outlines/models/anthropic.py index 4b5823952..cb03a172a 100644 --- a/outlines/models/anthropic.py +++ b/outlines/models/anthropic.py @@ -1,10 +1,12 @@ """Integration with Anthropic's API.""" from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union, List -from outlines.inputs import Chat, Image +from outlines.inputs import Chat, Image, Message, UserMessage, ToolCall from outlines.models.base import Model, ModelTypeAdapter +from outlines.outputs import Output, StreamingOutput, ToolCallOutput, StreamingToolCallOutput +from outlines.tools import ToolDef if TYPE_CHECKING: from anthropic import Anthropic as AnthropicClient @@ -47,66 +49,103 @@ def format_input(self, model_input): @format_input.register(str) def format_str_model_input(self, model_input: str) -> dict: return { - "messages": [self._create_message("user", model_input)] + "messages": [ + self._create_anthropic_message( + UserMessage( + role="user", + content=model_input + ) + ) + ] } @format_input.register(list) def format_list_model_input(self, model_input: list) -> dict: return { "messages": [ - self._create_message("user", model_input) + self._create_anthropic_message( + UserMessage( + role="user", + content=model_input + ) + ) ] } @format_input.register(Chat) def format_chat_model_input(self, model_input: Chat) -> dict: - """Generate the `messages` argument to pass to the client when the user - passes a Chat instance. - - """ return { "messages": [ - self._create_message(message["role"], message["content"]) + self._create_anthropic_message(message) for message in model_input.messages ] } - def _create_message(self, role: str, content: str | list) -> dict: - """Create a message.""" + def _create_anthropic_message(self, message: Message) -> dict: + """Create a message for the Anthropic client.""" + role = message.get("role", None) + content = message.get("content", None) + tool_calls: Optional[List[ToolCall]] = message.get("tool_calls", None) # type: ignore + tool_call_id = message.get("tool_call_id", None) - if isinstance(content, str): + if role == "system": + raise ValueError( + "System messages are not supported in Chat inputs for " + + "Anthropic. Use the `system` inference argument instead." + ) + elif role in ["user", "assistant"]: + if role == "assistant" and (content is None and tool_calls is None): + raise ValueError( + "Either content or tool calls is required for " + + "assistant messages" + ) + elif role == "user" and content is None: + raise ValueError(f"Content is required for {role} messages") + formatted_content = self._create_anthropic_content(content, tool_calls) return { "role": role, - "content": content, + "content": formatted_content, } + elif role == "tool": + if content is None or tool_call_id is None: + raise ValueError( + "Content and tool call id are required for " + + "tool messages" + ) + return { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": content, + } + ] + } + else: + raise ValueError( + f"Invalid message role: {role}. The role must be one of " + + "'user', 'assistant' or 'tool'." + ) + def _create_anthropic_content( + self, + content: str | list | None, + tool_calls: List[ToolCall] | None + ) -> list | None: + """Create the content for an Anthropic message.""" + content_parts = [] + if isinstance(content, str): + content_parts.append(self._create_anthropic_text_content_part(content)) elif isinstance(content, list): - prompt = content[0] + text = content[0] images = content[1:] - if not all(isinstance(image, Image) for image in images): raise ValueError("All assets provided must be of type Image") - - image_content_messages = [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": image.image_format, - "data": image.image_str, - }, - } - for image in images - ] - - return { - "role": role, - "content": [ - *image_content_messages, - {"type": "text", "text": prompt}, - ], - } - + content_parts.append(self._create_anthropic_text_content_part(text)) + content_parts.extend([self._create_anthropic_img_content_part(image) for image in images]) + elif not content: + pass else: raise ValueError( f"Invalid content type: {type(content)}. " @@ -114,6 +153,38 @@ def _create_message(self, role: str, content: str | list) -> dict: "and a list of images." ) + if tool_calls: + content_parts.extend([self._create_anthropic_tool_content_part(tool) for tool in tool_calls]) + + return content_parts + + def _create_anthropic_text_content_part(self, content: str) -> dict: + """Create a content part for a text input.""" + return { + "type": "text", + "text": content, + } + + def _create_anthropic_img_content_part(self, image: Image) -> dict: + """Create a content part for an image input.""" + return { + "type": "image", + "source": { + "type": "base64", + "media_type": image.image_format, + "data": image.image_str, + }, + } + + def _create_anthropic_tool_content_part(self, tool: ToolCall) -> dict: + """Create a content part for a tool call.""" + return { + "type": "tool_use", + "id": tool["tool_call_id"], + "name": tool["tool_name"], + "input": tool["args"], + } + def format_output_type(self, output_type): """Not implemented for Anthropic.""" if output_type is None: @@ -124,6 +195,37 @@ def format_output_type(self, output_type): "Anthropic." ) + def format_tools(self, tools: Optional[List[ToolDef]]) -> Optional[list]: + """Format the tools for the Anthropic client. + + Parameters + ---------- + tools + A list of ToolDef instances. + + Returns + ------- + list + The formatted tools to pass to the Anthropic client. + + """ + if not tools: + return None + + formatted_tools = [] + for tool in tools: + formatted_tools.append({ + "name": tool["name"], + "description": tool["description"], + "input_schema": { + "type": "object", + "properties": tool["parameters"], + "required": tool["required"], + }, + }) + + return formatted_tools + class Anthropic(Model): """Thin wrapper around the `anthropic.Anthropic` client. @@ -151,9 +253,10 @@ def __init__( def generate( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> str: + ) -> Output: """Generate text using Anthropic. Parameters @@ -164,16 +267,19 @@ def generate( As structured generation is not supported by Anthropic, the value of this argument must be `None`. Otherwise, an error will be raised at runtime. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The response generated by the model. """ messages = self.type_adapter.format_input(model_input) + tools = self.type_adapter.format_tools(tools) if output_type is not None: raise NotImplementedError( @@ -186,16 +292,21 @@ def generate( ): inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools + completion = self.client.messages.create( **messages, **inference_kwargs, ) - return completion.content[0].text + + return self._handle_anthropic_response(completion) def generate_batch( self, model_input, - output_type = None, + output_type, + tools: Optional[List[ToolDef]], **inference_kwargs, ): raise NotImplementedError( @@ -205,9 +316,10 @@ def generate_batch( def generate_stream( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using Anthropic. Parameters @@ -218,16 +330,19 @@ def generate_stream( As structured generation is not supported by Anthropic, the value of this argument must be `None`. Otherwise, an error will be raised at runtime. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ messages = self.type_adapter.format_input(model_input) + tools = self.type_adapter.format_tools(tools) if output_type is not None: raise NotImplementedError( @@ -240,18 +355,99 @@ def generate_stream( ): inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools + stream = self.client.messages.create( **messages, stream=True, **inference_kwargs, ) + yield from self._process_streaming_chunks(stream) + + def _process_streaming_chunks( + self, stream: Iterator[Any] + ) -> Iterator[StreamingOutput]: + """Process streaming chunks from Anthropic API and convert them to + StreamingOutput instances. + + Parameters + ---------- + stream + The stream from the Anthropic API. + + Yields + ------ + Iterator[StreamingOutput] + An iterator that yields the StreamingOutput instances. + + """ + # This is needed as Anthropic first provide a chunk for the type of the + # block to follow and then only the content of the block. + current_block_type = None + tool_call_id = None + tool_name = None + for chunk in stream: - if ( - chunk.type == "content_block_delta" - and chunk.delta.type == "text_delta" - ): - yield chunk.delta.text + if chunk.type == "content_block_start": + if chunk.content_block.type == "text": + current_block_type = "text" + elif chunk.content_block.type == "tool_use": + current_block_type = "tool_use" + tool_call_id = chunk.content_block.id + tool_name = chunk.content_block.name + + elif chunk.type == "content_block_delta": + if current_block_type == "text": + yield StreamingOutput( + content=chunk.delta.text + ) + elif current_block_type == "tool_use": + yield StreamingOutput( + tool_calls=[ + StreamingToolCallOutput( + name=tool_name, # type: ignore + args=str(chunk.delta.partial_json), + id=tool_call_id + ) + ], + ) + + def _handle_anthropic_response(self, response) -> Output: + """Convert the response from the Anthropic API to an Output. + + Parameters + ---------- + response + The response from the Anthropic API. + + Returns + ------- + Output + The Output. + + """ + content = None + tool_calls = [] + + if hasattr(response, 'content') and response.content: + for content_block in response.content: + if content_block.type == "tool_use": + tool_calls.append( + ToolCallOutput( + name=content_block.name, + args=content_block.input, + id=content_block.id + ) + ) + elif content_block.type == "text": + content = content_block.text + + return Output( + content=content, + tool_calls=tool_calls or None, + ) def from_anthropic( @@ -264,7 +460,7 @@ def from_anthropic( ---------- client An `anthropic.Anthropic` client instance. - model_name + model_name The name of the model to use. Returns diff --git a/outlines/models/base.py b/outlines/models/base.py index 2ad0407f3..0534aa669 100644 --- a/outlines/models/base.py +++ b/outlines/models/base.py @@ -3,6 +3,9 @@ from abc import ABC, abstractmethod from typing import Any, AsyncIterator, Iterator, List, Optional +from outlines.tools import ToolDef, ToolsInput +from outlines.outputs import Output, StreamingOutput + class ModelTypeAdapter(ABC): """Base class for all model type adapters. @@ -39,7 +42,7 @@ def format_input(self, model_input: Any) -> Any: ... @abstractmethod - def format_output_type(self, output_type: Optional[Any] = None) -> Any: + def format_output_type(self, output_type: Optional[Any]) -> Any: """Format the output type to the expected format of the model. For black-box models, this typically means creating a `response_format` @@ -59,6 +62,24 @@ def format_output_type(self, output_type: Optional[Any] = None) -> Any: """ ... + @abstractmethod + def format_tools(self, tools: Optional[List[ToolDef]]) -> Optional[list]: + """Format the tools to the expected format of the model. + + Parameters + ---------- + tools + A list of tools to format. + + Returns + ------- + Optional[list] + The formatted tools to be passed to the model. If no tools are + provided, returns `None`. + + """ + ... + class Model(ABC): """Base class for all synchronous models. @@ -82,8 +103,10 @@ def __call__( model_input: Any, output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> Any: + ) -> Output | List[Output]: """Call the model. Users can call the model directly, in which case we will create a @@ -108,26 +131,31 @@ def __call__( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ from outlines.generator import Generator - return Generator(self, output_type, backend)(model_input, **inference_kwargs) + return Generator(self, output_type, backend, tools=tools)(model_input, **inference_kwargs) # type: ignore def batch( self, model_input: List[Any], output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> List[Any]: + ) -> List[Output] | List[List[Output]]: """Make a batch call to the model (several inputs at once). Users can use the `batch` method from the model directly, in which @@ -153,18 +181,21 @@ def batch( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ - from outlines import Generator + from outlines.generator import Generator - generator = Generator(self, output_type, backend) + generator = Generator(self, output_type, backend, tools=tools) return generator.batch(model_input, **inference_kwargs) # type: ignore def stream( @@ -172,8 +203,10 @@ def stream( model_input: Any, output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> Iterator[Any]: + ) -> Iterator[StreamingOutput]: """Stream a response from the model. Users can use the `stream` method from the model directly, in which @@ -201,27 +234,31 @@ def stream( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Iterator[Any] - A stream of responses from the model. + Iterator[StreamingOutput] + A stream of StreamingOutput produced by the model. """ - from outlines import Generator + from outlines.generator import Generator - generator = Generator(self, output_type, backend) + generator = Generator(self, output_type, backend, tools=tools) return generator.stream(model_input, **inference_kwargs) # type: ignore @abstractmethod def generate( self, model_input: Any, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> Any: + ) -> Output | List[Output]: """Generate a response from the model. The output_type argument contains a logits processor for steerable @@ -234,13 +271,15 @@ def generate( The input provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ ... @@ -249,9 +288,10 @@ def generate( def generate_batch( self, model_input: List[Any], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> List[Any]: + ) -> List[Output] | List[List[Output]]: """Generate a batch of responses from the model. The output_type argument contains a logits processor for steerable @@ -264,23 +304,27 @@ def generate_batch( The list of inputs provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ ... + @abstractmethod def generate_stream( self, model_input: Any, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> Iterator[Any]: + ) -> Iterator[StreamingOutput]: """Generate a stream of responses from the model. The output_type argument contains a logits processor for steerable @@ -293,13 +337,15 @@ def generate_stream( The input provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Iterator[Any] - A stream of responses from the model. + Iterator[StreamingOutput] + A stream of StreamingOutput produced by the model. """ ... @@ -327,8 +373,10 @@ async def __call__( model_input: Any, output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> Any: + ) -> Output | List[Output]: """Call the model. Users can call the model directly, in which case we will create a @@ -353,27 +401,32 @@ async def __call__( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ - from outlines import Generator + from outlines.generator import Generator - generator = Generator(self, output_type, backend) - return await generator(model_input, **inference_kwargs) + generator = Generator(self, output_type, backend, tools=tools) + return await generator(model_input, **inference_kwargs) # type: ignore async def batch( self, model_input: List[Any], output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> List[Any]: + ) -> List[Output] | List[List[Output]]: """Make a batch call to the model (several inputs at once). Users can use the `batch` method from the model directly, in which @@ -399,18 +452,21 @@ async def batch( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ - from outlines import Generator + from outlines.generator import Generator - generator = Generator(self, output_type, backend) + generator = Generator(self, output_type, backend, tools=tools) return await generator.batch(model_input, **inference_kwargs) # type: ignore async def stream( @@ -418,8 +474,10 @@ async def stream( model_input: Any, output_type: Optional[Any] = None, backend: Optional[str] = None, + *, + tools: Optional[ToolsInput] = None, **inference_kwargs: Any - ) -> AsyncIterator[Any]: + ) -> AsyncIterator[StreamingOutput]: """Stream a response from the model. Users can use the `stream` method from the model directly, in which @@ -447,18 +505,22 @@ async def stream( The name of the backend to use to create the logits processor that will be used to generate the response. Only used for steerable models if `output_type` is provided. + tools + A list of tools to provide to the model. Can contain an MCPServer, + a list of ToolDef, Callable, or BaseModel instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - AsyncIterator[Any] - A stream of responses from the model. + AsyncIterator[StreamingOutput] + A coroutine that will produce an async iterator of StreamingOutput + produced by the model. """ - from outlines import Generator + from outlines.generator import Generator - generator = Generator(self, output_type, backend) + generator = Generator(self, output_type, backend, tools=tools) async for chunk in generator.stream(model_input, **inference_kwargs): # type: ignore yield chunk @@ -467,9 +529,10 @@ async def stream( async def generate( self, model_input: Any, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> Any: + ) -> Output | List[Output]: """Generate a response from the model. The output_type argument contains a logits processor for steerable @@ -482,13 +545,15 @@ async def generate( The input provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - Any - The response generated by the model. + Output | List[Output] + The output generated by the model. """ ... @@ -497,9 +562,10 @@ async def generate( async def generate_batch( self, model_input: List[Any], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> List[Any]: + ) -> List[Output] | List[List[Output]]: """Generate a batch of responses from the model. The output_type argument contains a logits processor for steerable @@ -512,13 +578,15 @@ async def generate_batch( The list of inputs provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - List[Any] - The list of responses generated by the model. + List[Output] | List[List[Output]] + The list of outputs generated by the model. """ ... @@ -527,9 +595,10 @@ async def generate_batch( async def generate_stream( self, model_input: Any, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any - ) -> AsyncIterator[Any]: + ) -> AsyncIterator[StreamingOutput]: """Generate a stream of responses from the model. The output_type argument contains a logits processor for steerable @@ -542,13 +611,16 @@ async def generate_stream( The input provided by the user. output_type The output type provided by the user. + tools + A list of ToolDef instances. **inference_kwargs Additional keyword arguments to pass to the model. Returns ------- - AsyncIterator[Any] - A coroutine that will produce an async iterator of responses from the model. + AsyncIterator[StreamingOutput] + A coroutine that will produce an async iterator of StreamingOutput + produced by the model. """ ... diff --git a/outlines/models/dottxt.py b/outlines/models/dottxt.py index f88d3d5d6..57749de9b 100644 --- a/outlines/models/dottxt.py +++ b/outlines/models/dottxt.py @@ -1,11 +1,13 @@ """Integration with Dottxt's API.""" import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, List from pydantic import TypeAdapter from outlines.models.base import Model, ModelTypeAdapter +from outlines.outputs import Output +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema, Regex from outlines.types.utils import ( is_dataclass, @@ -44,7 +46,7 @@ def format_input(self, model_input: str) -> str: "The only available type is `str`." ) - def format_output_type(self, output_type: Optional[Any] = None) -> str: + def format_output_type(self, output_type: Optional[Any]) -> str: """Format the output type to pass to the client. TODO: `int`, `float` and other Python types could be supported via @@ -98,6 +100,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> str: "Consider using a local mode instead." ) + def format_tools(self, tools): + """Not implemented for Dottxt.""" + if tools: + raise NotImplementedError( + "Dottxt does not support tools." + ) + class Dottxt(Model): """Thin wrapper around the `dottxt.client.Dottxt` client. @@ -132,9 +141,10 @@ def __init__( def generate( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> str: + ) -> Output: """Generate text using Dottxt. Parameters @@ -145,15 +155,18 @@ def generate( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The text generated by the model. """ + self.type_adapter.format_tools(tools) prompt = self.type_adapter.format_input(model_input) json_schema = self.type_adapter.format_output_type(output_type) @@ -174,14 +187,17 @@ def generate( json_schema, **inference_kwargs, ) - return completion.data + + return Output(content=completion.data) def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): + """Not available for Dottxt.""" raise NotImplementedError( "Dottxt does not support batch generation." ) @@ -189,7 +205,8 @@ def generate_batch( def generate_stream( self, model_input, - output_type=None, + output_type, + tools, **inference_kwargs, ): """Not available for Dottxt.""" diff --git a/outlines/models/gemini.py b/outlines/models/gemini.py index 742dfa7a1..78e51dc27 100644 --- a/outlines/models/gemini.py +++ b/outlines/models/gemini.py @@ -8,11 +8,26 @@ Optional, Union, get_args, + List, ) -from outlines.inputs import Image, Chat +from outlines.inputs import ( + Chat, + Image, + Message, + UserMessage, + ToolCall, + ToolMessage, +) from outlines.models.base import Model, ModelTypeAdapter +from outlines.outputs import ( + Output, + StreamingOutput, + StreamingToolCallOutput, + ToolCallOutput +) from outlines.types import CFG, Choice, JsonSchema, Regex +from outlines.tools import ToolDef from outlines.types.utils import ( is_dataclass, is_enum, @@ -63,62 +78,112 @@ def format_input(self, model_input): @format_input.register(str) def format_str_model_input(self, model_input: str) -> dict: - return {"contents": [self._create_text_part(model_input)]} + return { + "contents": [ + self._create_message(UserMessage( + role="user", + content=model_input + )) + ] + } @format_input.register(list) def format_list_model_input(self, model_input: list) -> dict: return { "contents": [ - self._create_message("user", model_input) + self._create_message(UserMessage( + role="user", + content=model_input + )) ] } @format_input.register(Chat) def format_chat_model_input(self, model_input: Chat) -> dict: - """Generate the `contents` argument to pass to the client when the user - passes a Chat instance. - - """ return { "contents": [ - self._create_message(message["role"], message["content"]) + self._create_message(message) for message in model_input.messages ] } - def _create_message(self, role: str, content: str | list) -> dict: - """Create a message.""" + def _create_message(self, message: Message) -> dict: + """Create a Gemini message.""" + role = message.get("role", None) + content = message.get("content", None) + tool_calls = message.get("tool_calls", None) + tool_name = message.get("tool_name", None) - # Gemini uses "model" instead of "assistant" - if role == "assistant": - role = "model" + content_parts = self._create_content_parts(content) + tool_call_parts = self._create_tool_call_parts( + tool_calls if isinstance(tool_calls, list) else None + ) - if isinstance(content, str): + if role == "system": + raise ValueError( + "System messages are not supported in Chat inputs for " + + "Gemini. Use the `system_instruction` inference argument " + + "instead." + ) + elif role == "user": + if not content: + raise ValueError( + "Content is required for user messages" + ) return { "role": role, - "parts": [self._create_text_part(content)], + "parts": content_parts, + } + elif role == "assistant": + if not content and not tool_calls: + raise ValueError( + "Either content or tool calls is required for " + + "assistant messages" + ) + return { + "role": "model", + "parts": [ + *content_parts, + *tool_call_parts, + ], + } + elif role == "tool": + if not content or not tool_name: + raise ValueError( + "Content and tool name are required for " + + "tool messages" + ) + return { + "role": "user", + "parts": [self._create_tool_response_part(message)], # type: ignore } + else: + raise ValueError( + f"Invalid message role: {role}. " + "The role must be one of 'user', 'assistant' or 'tool'." + ) + def _create_content_parts( + self, content: Optional[str | list] + ) -> List[dict]: + """Create Gemini message parts from a content.""" + if content is None: + return [] + if isinstance(content, str): + return [self._create_text_part(content)] elif isinstance(content, list): - prompt = content[0] + text = content[0] images = content[1:] - if not all(isinstance(image, Image) for image in images): raise ValueError("All assets provided must be of type Image") - image_parts = [ self._create_img_part(image) for image in images ] - - return { - "role": role, - "parts": [ - self._create_text_part(prompt), - *image_parts, - ], - } - + return [ + self._create_text_part(text), + *image_parts, + ] else: raise ValueError( f"Invalid content type: {type(content)}. " @@ -126,8 +191,15 @@ def _create_message(self, role: str, content: str | list) -> dict: "and a list of images." ) - return {"contents": [prompt, *image_parts]} - + def _create_tool_call_parts(self, tool_calls: Optional[List[ToolCall]]) -> List[dict]: + """Create Gemini message parts from tool calls.""" + if tool_calls is None: + return [] + else: + return [ + self._create_tool_call_part(tool_call) + for tool_call in tool_calls + ] def _create_text_part(self, text: str) -> dict: """Create a text input part for a message.""" @@ -144,7 +216,27 @@ def _create_img_part(self, image: Image) -> dict: } } - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def _create_tool_call_part(self, tool_call: ToolCall) -> dict: + """Create a tool call input part for a message.""" + return { + "function_call": { + "id": tool_call["tool_call_id"], + "name": tool_call["tool_name"], + "args": tool_call["args"], + } + } + + def _create_tool_response_part(self, tool_message: ToolMessage) -> dict: + """Create a tool response input part for a message.""" + return { + "function_response": { + "id": tool_message.get("tool_call_id", None), + "name": tool_message["tool_name"], + "response": tool_message["content"], + } + } + + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the `generation_config` argument to pass to the client. Parameters @@ -256,6 +348,40 @@ def format_list_output_type(self, output_type: Optional[Any]) -> dict: f"Got {output_type} instead." ) + def format_tools(self, tools: Optional[List[ToolDef]]) -> Optional[list]: + """Format the tools for the Gemini client. + + Parameters + ---------- + tools + A list of ToolDef instances. + + Returns + ------- + Optional[list] + The formatted tools to pass to the Gemini client. If no tools are + provided, returns `None`. + + """ + if not tools: + return None + + formatted_tools = [] + for tool in tools: + formatted_tools.append({ + "function_declarations": [{ + "name": tool["name"], + "description": tool["description"], + "parameters": { + "type": "object", + "properties": tool["parameters"], + "required": tool["required"], + }, + }] + }) + + return formatted_tools + class Gemini(Model): """Thin wrapper around the `google.genai.Client` client. @@ -283,9 +409,10 @@ def __init__(self, client: "Client", model_name: Optional[str] = None): def generate( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs, - ) -> str: + ) -> Output: """Generate a response from the model. Parameters @@ -296,30 +423,39 @@ def generate( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema, a list of such types, or a multiple choice type. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The response generated by the model. """ contents = self.type_adapter.format_input(model_input) generation_config = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) + + inference_kwargs.update(**generation_config) + + if tools: + inference_kwargs["tools"] = tools completion = self.client.models.generate_content( **contents, model=inference_kwargs.pop("model", self.model_name), - config={**generation_config, **inference_kwargs} + config=inference_kwargs ) - return completion.text + return self._handle_gemini_response(completion) def generate_batch( self, model_input, - output_type = None, + output_type, + tools: Optional[List[ToolDef]], **inference_kwargs, ): raise NotImplementedError( @@ -329,9 +465,10 @@ def generate_batch( def generate_stream( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Generate a stream of responses from the model. Parameters @@ -342,17 +479,26 @@ def generate_stream( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema, a list of such types, or a multiple choice type. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] - An iterator that yields the text generated by the model. + Iterator[StreamingOutput] + An iterator that yields the StreamingOutput generated by the model. """ contents = self.type_adapter.format_input(model_input) generation_config = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) + + if "model" not in inference_kwargs and self.model_name is not None: + inference_kwargs["model"] = self.model_name + + if tools: + generation_config["tools"] = tools stream = self.client.models.generate_content_stream( **contents, @@ -361,8 +507,79 @@ def generate_stream( ) for chunk in stream: - if hasattr(chunk, "text") and chunk.text: - yield chunk.text + streaming_output = self._handle_gemini_stream_chunk(chunk) + if streaming_output is not None: + yield streaming_output + + def _handle_gemini_response(self, response) -> Output: + """Convert the response from the Gemini API to an Output. + + Parameters + ---------- + response + The response from the Gemini API. + + Returns + ------- + Output + The Output. + + """ + if hasattr(response, "candidates") and response.candidates: + candidate = response.candidates[0] + if hasattr(candidate, "content") and candidate.content: + tool_calls = [] + content = None + + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + tool_calls.append( + ToolCallOutput( + name=part.function_call.name, + args=part.function_call.args, + ) + ) + elif hasattr(part, "text") and part.text: + content = part.text + + return Output( + content=content, + tool_calls=tool_calls, # type: ignore + ) + + return Output(content=response.text) + + def _handle_gemini_stream_chunk(self, chunk) -> Optional[StreamingOutput]: + """Convert the streaming chunk from the Gemini API to a StreamingOutput. + + Parameters + ---------- + chunk + The streaming chunk from the Gemini API. + + Returns + ------- + Optional[StreamingOutput] + The text generated by the model. + + """ + if hasattr(chunk, "candidates") and chunk.candidates: + candidate = chunk.candidates[0] + if hasattr(candidate, "content") and candidate.content: + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + return StreamingOutput( + tool_calls=[ + StreamingToolCallOutput( + name=part.function_call.name, + args=str(part.function_call.args), + ) + ], + ) + elif hasattr(part, "text") and part.text: + return StreamingOutput(content=part.text) + + return None def from_gemini(client: "Client", model_name: Optional[str] = None) -> Gemini: diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index c6a3bf8b0..5db64716a 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -17,7 +17,9 @@ from outlines.inputs import Chat from outlines.models.base import Model, ModelTypeAdapter from outlines.models.tokenizer import Tokenizer +from outlines.outputs import Output, StreamingOutput from outlines.processors import OutlinesLogitsProcessor +from outlines.tools import ToolDef if TYPE_CHECKING: from llama_cpp import Llama, LogitsProcessorList @@ -196,7 +198,7 @@ def format_chat_input(self, model_input: Chat) -> list: ] def format_output_type( - self, output_type: Optional[OutlinesLogitsProcessor] = None, + self, output_type: Optional[OutlinesLogitsProcessor], ) -> "LogitsProcessorList": """Generate the logits processor argument to pass to the model. @@ -215,6 +217,13 @@ def format_output_type( return LogitsProcessorList([output_type]) + def format_tools(self, tools): + """Not available for LlamaCpp.""" + if tools: + raise NotImplementedError( + "LlamaCpp does not support tools." + ) + class LlamaCpp(Model): """Thin wrapper around the `llama_cpp.Llama` model. @@ -240,9 +249,10 @@ def __init__(self, model: "Llama"): def generate( self, model_input: Union[Chat, str], - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> str: + ) -> Output: """Generate text using `llama-cpp-python`. Parameters @@ -252,6 +262,8 @@ def generate( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the `Llama.__call__` method of the `llama-cpp-python` library. @@ -262,41 +274,46 @@ def generate( The text generated by the model. """ + self.type_adapter.format_tools(tools) prompt = self.type_adapter.format_input(model_input) + logits_processor = self.type_adapter.format_output_type(output_type) if isinstance(prompt, str): completion = self.model( prompt, - logits_processor=self.type_adapter.format_output_type(output_type), + logits_processor=logits_processor, **inference_kwargs, ) result = completion["choices"][0]["text"] elif isinstance(prompt, list): # pragma: no cover completion = self.model.create_chat_completion( prompt, - logits_processor=self.type_adapter.format_output_type(output_type), + logits_processor=logits_processor, **inference_kwargs, ) result = completion["choices"][0]["message"]["content"] self.model.reset() - return result + return Output(content=result) def generate_batch( self, model_input, - output_type = None, + output_type, **inference_kwargs, ): - raise NotImplementedError("LlamaCpp does not support batch generation.") + raise NotImplementedError( + "LlamaCpp does not support batch generation." + ) def generate_stream( self, model_input: Union[Chat, str], - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using `llama-cpp-python`. Parameters @@ -306,6 +323,8 @@ def generate_stream( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the `Llama.__call__` method of the `llama-cpp-python` library. @@ -316,27 +335,33 @@ def generate_stream( An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) prompt = self.type_adapter.format_input(model_input) + logits_processor = self.type_adapter.format_output_type(output_type) if isinstance(prompt, str): generator = self.model( prompt, - logits_processor=self.type_adapter.format_output_type(output_type), + logits_processor=logits_processor, stream=True, **inference_kwargs, ) for chunk in generator: - yield chunk["choices"][0]["text"] + yield StreamingOutput( + content=chunk["choices"][0]["text"] + ) elif isinstance(prompt, list): # pragma: no cover generator = self.model.create_chat_completion( prompt, - logits_processor=self.type_adapter.format_output_type(output_type), + logits_processor=logits_processor, stream=True, **inference_kwargs, ) for chunk in generator: - yield chunk["choices"][0]["delta"].get("content", "") + yield StreamingOutput( + content=chunk["choices"][0]["delta"].get("content", "") + ) def from_llamacpp(model: "Llama"): diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index 498fec2dc..97dca71d6 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -6,7 +6,9 @@ from outlines.inputs import Chat from outlines.models.base import Model, ModelTypeAdapter from outlines.models.transformers import TransformerTokenizer +from outlines.outputs import Output, StreamingOutput from outlines.processors import OutlinesLogitsProcessor +from outlines.tools import ToolDef if TYPE_CHECKING: import mlx.nn as nn @@ -37,7 +39,7 @@ def format_input(self, model_input): """ raise NotImplementedError( - f"The input type {input} is not available with mlx-lm. " + f"The input type {model_input} is not available with mlx-lm. " "The available types are `str` and `Chat`." ) @@ -63,7 +65,7 @@ def format_chat_input(self, model_input: Chat) -> str: ) def format_output_type( - self, output_type: Optional[OutlinesLogitsProcessor] = None, + self, output_type: Optional[OutlinesLogitsProcessor], ) -> Optional[List[OutlinesLogitsProcessor]]: """Generate the logits processor argument to pass to the model. @@ -83,6 +85,14 @@ def format_output_type( return [output_type] + def format_tools(self, tools): + """Not available for MLXLM.""" + if tools: + raise NotImplementedError( + "MLXLM does not support tools." + ) + + class MLXLM(Model): """Thin wrapper around an `mlx_lm` model. @@ -118,9 +128,10 @@ def __init__( def generate( self, model_input: str, - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **kwargs, - ) -> str: + ) -> Output: """Generate text using `mlx-lm`. Parameters @@ -130,18 +141,22 @@ def generate( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. kwargs Additional keyword arguments to pass to the `mlx-lm` library. Returns ------- - str + Output The text generated by the model. """ from mlx_lm import generate - return generate( + self.type_adapter.format_tools(tools) + + result = generate( self.model, self.mlx_tokenizer, self.type_adapter.format_input(model_input), @@ -149,10 +164,13 @@ def generate( **kwargs, ) + return Output(content=result.text) + def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **kwargs, ): raise NotImplementedError( @@ -162,9 +180,10 @@ def generate_batch( def generate_stream( self, model_input: str, - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **kwargs, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using `mlx-lm`. Parameters @@ -174,17 +193,21 @@ def generate_stream( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. kwargs Additional keyword arguments to pass to the `mlx-lm` library. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ from mlx_lm import stream_generate + self.type_adapter.format_tools(tools) + for gen_response in stream_generate( self.model, self.mlx_tokenizer, @@ -192,7 +215,7 @@ def generate_stream( logits_processors=self.type_adapter.format_output_type(output_type), **kwargs, ): - yield gen_response.text + yield StreamingOutput(content=gen_response.text) def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM: diff --git a/outlines/models/ollama.py b/outlines/models/ollama.py index 24ea77eac..7693e39c0 100644 --- a/outlines/models/ollama.py +++ b/outlines/models/ollama.py @@ -2,12 +2,22 @@ import json from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Iterator, + List, + Optional, + Union, +) from pydantic import TypeAdapter from outlines.inputs import Chat, Image from outlines.models.base import AsyncModel, Model, ModelTypeAdapter +from outlines.outputs import Output, StreamingOutput +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema, Regex from outlines.types.utils import ( is_dataclass, @@ -74,7 +84,7 @@ def format_chat_model_input(self, model_input: Chat) -> list: """ return [ - self._create_message(message["role"], message["content"]) + self._create_message(message["role"], message["content"]) # type: ignore for message in model_input.messages ] @@ -107,9 +117,7 @@ def _create_message(self, role: str, content: str | list) -> dict: "and a list of images." ) - def format_output_type( - self, output_type: Optional[Any] = None - ) -> Optional[str]: + def format_output_type(self, output_type: Optional[Any]) -> Optional[str]: """Format the output type to pass to the client. TODO: `int`, `float` and other Python types could be supported via @@ -159,6 +167,13 @@ def format_output_type( "Consider using a local model instead." ) + def format_tools(self, tools): + """Not available for Ollama.""" + if tools: + raise NotImplementedError( + "Tools are not available for Ollama." + ) + class Ollama(Model): """Thin wrapper around the `ollama.Client` client. @@ -184,9 +199,10 @@ def __init__(self, client: "Client", model_name: Optional[str] = None): def generate(self, model_input: Chat | str | list, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **kwargs: Any, - ) -> str: + ) -> Output: """Generate text using Ollama. Parameters @@ -197,15 +213,19 @@ def generate(self, The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema. + tools + The tools to use for the generation. **kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The text generated by the model. """ + self.type_adapter.format_tools(tools) + if "model" not in kwargs and self.model_name is not None: kwargs["model"] = self.model_name @@ -214,12 +234,14 @@ def generate(self, format=self.type_adapter.format_output_type(output_type), **kwargs, ) - return response.message.content + + return Output(content=response.message.content) def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **kwargs, ): raise NotImplementedError( @@ -229,9 +251,10 @@ def generate_batch( def generate_stream( self, model_input: Chat | str | list, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using Ollama. Parameters @@ -242,15 +265,19 @@ def generate_stream( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema. + tools + The tools to use for the generation. **kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) + if "model" not in kwargs and self.model_name is not None: kwargs["model"] = self.model_name @@ -260,8 +287,9 @@ def generate_stream( stream=True, **kwargs, ) + for chunk in response: - yield chunk.message.content + yield StreamingOutput(content=chunk.message.content) class AsyncOllama(AsyncModel): @@ -290,9 +318,10 @@ def __init__( async def generate(self, model_input: Chat | str | list, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **kwargs: Any, - ) -> str: + ) -> Output: """Generate text using Ollama. Parameters @@ -303,15 +332,19 @@ async def generate(self, The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema. + tools + The tools to use for the generation. **kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The text generated by the model. """ + self.type_adapter.format_tools(tools) + if "model" not in kwargs and self.model_name is not None: kwargs["model"] = self.model_name @@ -320,12 +353,14 @@ async def generate(self, format=self.type_adapter.format_output_type(output_type), **kwargs, ) - return response.message.content + + return Output(content=response.message.content) async def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **kwargs, ): raise NotImplementedError( @@ -335,9 +370,10 @@ async def generate_batch( async def generate_stream( # type: ignore self, model_input: Chat | str | list, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **kwargs: Any, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[StreamingOutput]: """Stream text using Ollama. Parameters @@ -348,15 +384,19 @@ async def generate_stream( # type: ignore The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema. + tools + The tools to use for the generation. **kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) + if "model" not in kwargs and self.model_name is not None: kwargs["model"] = self.model_name @@ -366,8 +406,9 @@ async def generate_stream( # type: ignore stream=True, **kwargs, ) + async for chunk in stream: - yield chunk.message.content + yield StreamingOutput(content=chunk.message.content) def from_ollama( diff --git a/outlines/models/openai.py b/outlines/models/openai.py index dad987b93..cbd234f98 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,11 +1,13 @@ """Integration with OpenAI's API.""" +import ast import json from typing import ( TYPE_CHECKING, Any, AsyncIterator, Iterator, + List, Optional, Union, ) @@ -13,9 +15,23 @@ from pydantic import BaseModel, TypeAdapter -from outlines.inputs import Chat, Image +from outlines.inputs import ( + AssistantMessage, + Chat, + Image, + Message, + ToolCall, + UserMessage, +) from outlines.models.base import AsyncModel, Model, ModelTypeAdapter from outlines.models.utils import set_additional_properties_false_json_schema +from outlines.outputs import ( + Output, + ToolCallOutput, + StreamingOutput, + StreamingToolCallOutput, +) +from outlines.tools import ToolDef from outlines.types import JsonSchema, Regex, CFG from outlines.types.utils import ( is_dataclass, @@ -31,6 +47,8 @@ AsyncOpenAI as AsyncOpenAIClient, AzureOpenAI as AzureOpenAIClient, AsyncAzureOpenAI as AsyncAzureOpenAIClient, + ChatCompletionChunk, + ChatCompletion, ) __all__ = ["AsyncOpenAI", "OpenAI", "from_openai"] @@ -72,7 +90,12 @@ def format_str_model_input(self, model_input: str) -> list: """ return [ - self._create_message("user", model_input) + self._create_openai_message( + UserMessage( + role="user", + content=model_input + ) + ) ] @format_input.register(list) @@ -82,7 +105,12 @@ def format_list_model_input(self, model_input: list) -> list: """ return [ - self._create_message("user", model_input) + self._create_openai_message( + UserMessage( + role="user", + content=model_input + ) + ) ] @format_input.register(Chat) @@ -92,39 +120,75 @@ def format_chat_model_input(self, model_input: Chat) -> list: """ return [ - self._create_message(message["role"], message["content"]) + self._create_openai_message(message) for message in model_input.messages ] - def _create_message(self, role: str, content: str | list) -> dict: - """Create a message.""" + def _create_openai_message(self, message: Message) -> dict: + """Create a message for the OpenAI client.""" + role = message.get("role", None) + content = message.get("content", None) + tool_calls: Optional[List[ToolCall]] = message.get("tool_calls", None) # type: ignore + tool_call_id = message.get("tool_call_id", None) - if isinstance(content, str): + formatted_content = self._create_openai_content(content) + formatted_tool_calls = self._create_openai_tool_calls(tool_calls) + + if role in ["system", "user"]: + if formatted_content is None: + raise ValueError(f"Content is required for {role} messages") return { "role": role, - "content": content, + "content": formatted_content, } + elif role == "assistant": + if formatted_content is None and formatted_tool_calls is None: + raise ValueError( + "Either content or tool calls is required for " + + f"{role} messages" + ) + formatted_message: dict[str, Any] = {"role": role} + if formatted_content: + formatted_message["content"] = formatted_content + if formatted_tool_calls: + formatted_message["tool_calls"] = formatted_tool_calls + return formatted_message + elif role == "tool": + if formatted_content is None or tool_call_id is None: + raise ValueError( + "Content and tool call id are required for " + + f"{role} messages" + ) + return { + "role": role, + "content": formatted_content, + "tool_call_id": tool_call_id, + } + else: + raise ValueError( + f"Invalid message role: {role}. The role must be one of " + + "'system', 'user', 'assistant' or 'tool'." + ) + def _create_openai_content(self, content: str | list | None) -> str | list | None: + """Create the content for an OpenAI message.""" + if content is None: + return None + if isinstance(content, str): + return content elif isinstance(content, list): - prompt = content[0] + text = content[0] images = content[1:] - if not all(isinstance(image, Image) for image in images): raise ValueError("All assets provided must be of type Image") - image_parts = [ - self._create_img_content(image) + self._create_openai_img_content_part(image) for image in images ] - - return { - "role": role, - "content": [ - {"type": "text", "text": prompt}, - *image_parts, - ], - } - + return [ + self._create_openai_text_content_part(text), + *image_parts, + ] else: raise ValueError( f"Invalid content type: {type(content)}. " @@ -132,8 +196,15 @@ def _create_message(self, role: str, content: str | list) -> dict: "and a list of images." ) - def _create_img_content(self, image: Image) -> dict: - """Create the content for an image input.""" + def _create_openai_text_content_part(self, content: str) -> dict: + """Create a content part for a text input.""" + return { + "type": "text", + "text": content, + } + + def _create_openai_img_content_part(self, image: Image) -> dict: + """Create a content part for an image input.""" return { "type": "image_url", "image_url": { @@ -141,7 +212,25 @@ def _create_img_content(self, image: Image) -> dict: }, } - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def _create_openai_tool_calls( + self, tool_calls: List[ToolCall] | None + ) -> list | None: + """Create the tool calls argument for an OpenAI message.""" + if tool_calls is None: + return None + return [ + { + "type": "function", + "id": tool_call["tool_call_id"], + "function": { + "name": tool_call["tool_name"], + "arguments": str(tool_call["args"]), + }, + } + for tool_call in tool_calls + ] + + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the `response_format` argument to the client based on the output type specified by the user. @@ -224,6 +313,41 @@ def format_json_mode_type(self) -> dict: """ return {"response_format": {"type": "json_object"}} + def format_tools(self, tools: Optional[List[ToolDef]]) -> Optional[list]: + """Format the tools for the OpenAI client. + + Parameters + ---------- + tools + A list of ToolDef instances. + + Returns + ------- + Optional[list] + The formatted tools to pass to the OpenAI client. If no tools are + provided, returns `None`. + + """ + if not tools: + return None + + formatted_tools = [] + for tool in tools: + formatted_tools.append({ + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": { + "type": "object", + "properties": tool["parameters"], + "required": tool["required"], + }, + }, + }) + + return formatted_tools + class OpenAI(Model): """Thin wrapper around the `openai.OpenAI` client. @@ -254,9 +378,10 @@ def __init__( def generate( self, model_input: Union[Chat, list, str], - output_type: Optional[Union[type[BaseModel], str]] = None, + output_type: Optional[Union[type[BaseModel], str]], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Output | list[Output]: """Generate text using OpenAI. Parameters @@ -267,22 +392,27 @@ def generate( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema or an empty dictionary. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] - The text generated by the model. + Output | list[Output] + The response generated by the model. """ import openai messages = self.type_adapter.format_input(model_input) response_format = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) if "model" not in inference_kwargs and self.model_name is not None: inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools try: result = self.client.chat.completions.create( @@ -299,22 +429,13 @@ def generate( else: raise e - messages = [choice.message for choice in result.choices] - for message in messages: - if message.refusal is not None: - raise ValueError( - f"OpenAI refused to answer the request: {message.refusal}" - ) - - if len(messages) == 1: - return messages[0].content - else: - return [message.content for message in messages] + return _handle_openai_response(result) def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError( @@ -324,9 +445,10 @@ def generate_batch( def generate_stream( self, model_input: Union[Chat, list, str], - output_type: Optional[Union[type[BaseModel], str]] = None, + output_type: Optional[Union[type[BaseModel], str]], + tools: Optional[List[ToolDef]], **inference_kwargs, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using OpenAI. Parameters @@ -337,22 +459,27 @@ def generate_stream( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema or an empty dictionary. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] - An iterator that yields the text generated by the model. + Iterator[StreamingOutput] + An iterator that yields the StreamingOutput instances. """ import openai messages = self.type_adapter.format_input(model_input) response_format = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) if "model" not in inference_kwargs and self.model_name is not None: inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools try: stream = self.client.chat.completions.create( @@ -370,9 +497,50 @@ def generate_stream( else: raise e + yield from self._handle_streaming_response(stream) + + def _handle_streaming_response( + self, stream: Iterator["ChatCompletionChunk"] + ) -> Iterator[StreamingOutput]: + """Handle streaming response from OpenAI API. + + Parameters + ---------- + stream + The streaming response from OpenAI API. + + Yields + ------ + Iterator[StreamingOutput] + An iterator that yields the StreamingOutput instances. + + """ + # This is needed as OpenAI provides the tool name and call id only in + # the first delta for each tool call. + tool_name = None + tool_call_id = None + for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + delta = chunk.choices[0].delta + if delta.tool_calls: + # When using streaming, only one tool call is returned at a + # time. + if delta.tool_calls[0].function.name: + tool_name = delta.tool_calls[0].function.name + if delta.tool_calls[0].id: + tool_call_id = delta.tool_calls[0].id + yield StreamingOutput( + content=delta.content, + tool_calls=[ + StreamingToolCallOutput( + name=tool_name or "", + args=delta.tool_calls[0].function.arguments, + id=tool_call_id + ) + ], + ) + elif delta.content is not None: + yield StreamingOutput(content=delta.content) class AsyncOpenAI(AsyncModel): @@ -404,9 +572,10 @@ def __init__( async def generate( self, model_input: Union[Chat, list, str], - output_type: Optional[Union[type[BaseModel], str]] = None, + output_type: Optional[Union[type[BaseModel], str]], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Output | list[Output]: """Generate text using OpenAI. Parameters @@ -417,22 +586,27 @@ async def generate( The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema or an empty dictionary. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] - The text generated by the model. + Output | list[Output] + The response generated by the model. """ import openai messages = self.type_adapter.format_input(model_input) response_format = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) if "model" not in inference_kwargs and self.model_name is not None: inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools try: result = await self.client.chat.completions.create( @@ -449,22 +623,13 @@ async def generate( else: raise e - messages = [choice.message for choice in result.choices] - for message in messages: - if message.refusal is not None: - raise ValueError( - f"OpenAI refused to answer the request: {message.refusal}" - ) - - if len(messages) == 1: - return messages[0].content - else: - return [message.content for message in messages] + return _handle_openai_response(result) async def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError( @@ -474,9 +639,10 @@ async def generate_batch( async def generate_stream( # type: ignore self, model_input: Union[Chat, list, str], - output_type: Optional[Union[type[BaseModel], str]] = None, + output_type: Optional[Union[type[BaseModel], str]], + tools: Optional[List[ToolDef]], **inference_kwargs, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[StreamingOutput]: """Stream text using OpenAI. Parameters @@ -487,22 +653,27 @@ async def generate_stream( # type: ignore The desired format of the response generated by the model. The output type must be of a type that can be converted to a JSON schema or an empty dictionary. + tools + The tools to use for the generation. **inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] - An iterator that yields the text generated by the model. + AsyncIterator[StreamingOutput] + An iterator that yields the StreamingOutput instances. """ import openai messages = self.type_adapter.format_input(model_input) response_format = self.type_adapter.format_output_type(output_type) + tools = self.type_adapter.format_tools(tools) if "model" not in inference_kwargs and self.model_name is not None: inference_kwargs["model"] = self.model_name + if tools: + inference_kwargs["tools"] = tools try: stream = await self.client.chat.completions.create( @@ -520,9 +691,96 @@ async def generate_stream( # type: ignore else: raise e + async for output in self._handle_streaming_response(stream): + yield output + + async def _handle_streaming_response( + self, stream: AsyncIterator["ChatCompletionChunk"] + ) -> AsyncIterator[StreamingOutput]: + """Handle streaming response from OpenAI API. + + Parameters + ---------- + stream + The streaming response from OpenAI API. + + Yields + ------ + AsyncIterator[StreamingOutput] + An iterator that yields the StreamingOutput instances. + """ + # This is needed as OpenAI provides the tool name and call id only in + # the first delta for each tool call. + tool_name = None + tool_call_id = None + async for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + delta = chunk.choices[0].delta + if delta.tool_calls: + # When using streaming, only one tool call is returned at a + # time. + if delta.tool_calls[0].function.name: + tool_name = delta.tool_calls[0].function.name + if delta.tool_calls[0].id: + tool_call_id = delta.tool_calls[0].id + yield StreamingOutput( + content=delta.content, + tool_calls=[ + StreamingToolCallOutput( + name=tool_name or "", + args=delta.tool_calls[0].function.arguments, + id=tool_call_id + ) + ], + ) + elif delta.content is not None: + yield StreamingOutput(content=delta.content) + + +def _handle_openai_response( + response: "ChatCompletion" +) -> Output | List[Output]: + """Convert the response from the OpenAI API to an Output or a + list of Outputs. + + Parameters + ---------- + response + The response from the OpenAI API. + + Returns + ------- + Output | List[Output] + The Output or list of Outputs. + + """ + messages = [choice.message for choice in response.choices] + + outputs = [] + for message in messages: + if message.refusal is not None: + raise ValueError( + f"OpenAI refused to answer the request: {message.refusal}" + ) + if message.tool_calls: + outputs.append(Output( + content=message.content, + tool_calls=[ + ToolCallOutput( + name=tool.function.name, + args=ast.literal_eval(tool.function.arguments), + id=tool.id + ) + for tool in message.tool_calls + ], + )) + else: + outputs.append(Output(content=message.content)) + + if len(outputs) == 1: + return outputs[0] + else: + return outputs def from_openai( @@ -560,5 +818,5 @@ def from_openai( else: raise ValueError( "Invalid client type. The client must be an instance of " - "+ `openai.OpenAI` or `openai.AsyncOpenAI`." + + "`openai.OpenAI` or `openai.AsyncOpenAI`." ) diff --git a/outlines/models/sglang.py b/outlines/models/sglang.py index 2a8366d79..949fbf281 100644 --- a/outlines/models/sglang.py +++ b/outlines/models/sglang.py @@ -3,12 +3,14 @@ import json import warnings from typing import ( - TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union + TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union ) from outlines.inputs import Chat from outlines.models.base import AsyncModel, Model, ModelTypeAdapter from outlines.models.openai import OpenAITypeAdapter +from outlines.outputs import Output, StreamingOutput +from outlines.tools import ToolDef from outlines.types.dsl import ( CFG, JsonSchema, @@ -44,7 +46,7 @@ def format_input(self, model_input: Union[Chat, list, str]) -> list: """ return OpenAITypeAdapter().format_input(model_input) - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the structured output argument to pass to the client. Parameters @@ -78,6 +80,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict: else: return {"extra_body": {"regex": to_regex(term)}} + def format_tools(self, tools): + """Not available for SGLang.""" + if tools: + raise NotImplementedError( + "Tools are not available for SGLang." + ) + class SGLang(Model): """Thin wrapper around the `openai.OpenAI` client used to communicate with @@ -106,9 +115,10 @@ def __init__(self, client, model_name: Optional[str] = None): def generate( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Output | list[Output]: """Generate text using SGLang. Parameters @@ -119,15 +129,18 @@ def generate( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] + Output | list[Output] The text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, @@ -145,14 +158,15 @@ def generate( ) if len(messages) == 1: - return messages[0].content + return Output(content=messages[0].content) else: - return [message.content for message in messages] + return [Output(content=message.content) for message in messages] def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError( @@ -162,9 +176,10 @@ def generate_batch( def generate_stream( self, model_input: Union[Chat, list, str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using SGLang. Parameters @@ -175,15 +190,18 @@ def generate_stream( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -194,12 +212,12 @@ def generate_stream( for chunk in stream: # pragma: no cover if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + yield StreamingOutput(content=chunk.choices[0].delta.content) def _build_client_args( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the SGLang client.""" @@ -250,9 +268,10 @@ def __init__(self, client, model_name: Optional[str] = None): async def generate( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Union[Output, list[Output]]: """Generate text using `sglang`. Parameters @@ -263,15 +282,18 @@ async def generate( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] + Union[Output, list[Output]] The text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -287,14 +309,15 @@ async def generate( ) if len(messages) == 1: - return messages[0].content + return Output(content=messages[0].content) else: - return [message.content for message in messages] + return [Output(content=message.content) for message in messages] async def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError( @@ -304,9 +327,10 @@ async def generate_batch( async def generate_stream( # type: ignore self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[StreamingOutput]: """Return a text generator. Parameters @@ -317,15 +341,18 @@ async def generate_stream( # type: ignore The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - AsyncIterator[str] + AsyncIterator[StreamingOutput] An async iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -337,12 +364,12 @@ async def generate_stream( # type: ignore async for chunk in stream: # pragma: no cover if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + yield StreamingOutput(content=chunk.choices[0].delta.content) def _build_client_args( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the SGLang client.""" diff --git a/outlines/models/tgi.py b/outlines/models/tgi.py index 31e9f5c20..c1d161878 100644 --- a/outlines/models/tgi.py +++ b/outlines/models/tgi.py @@ -7,11 +7,14 @@ Any, AsyncIterator, Iterator, + List, Optional, Union, ) -from outlines.models.base import AsyncModel,Model, ModelTypeAdapter +from outlines.models.base import AsyncModel, Model, ModelTypeAdapter +from outlines.outputs import Output, StreamingOutput +from outlines.tools import ToolDef from outlines.types.dsl import python_types_to_terms, to_regex, JsonSchema, CFG if TYPE_CHECKING: @@ -47,7 +50,7 @@ def format_input(self, model_input): def format_str_input(self, model_input: str) -> str: return model_input - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the structured output argument to pass to the client. Argument @@ -84,6 +87,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict: } } + def format_tools(self, tools): + """Not available for TGI.""" + if tools: + raise NotImplementedError( + "Tools are not available for TGI." + ) + class TGI(Model): """Thin wrapper around a `huggingface_hub.InferenceClient` client used to @@ -109,9 +119,10 @@ def __init__(self, client): def generate( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> str: + ) -> Output: """Generate text using TGI. Parameters @@ -122,27 +133,33 @@ def generate( The desired format of the response generated by the model. All output types except `CFG` are supported provided your server uses a backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) - return self.client.text_generation(**client_args) + response = self.client.text_generation(**client_args) + + return Output(content=response) def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError("TGI does not support batch inference.") @@ -150,9 +167,10 @@ def generate_batch( def generate_stream( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using TGI. Parameters @@ -163,15 +181,18 @@ def generate_stream( The desired format of the response generated by the model. All output types except `CFG` are supported provided your server uses a backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -181,12 +202,12 @@ def generate_stream( ) for chunk in stream: # pragma: no cover - yield chunk + yield StreamingOutput(content=chunk) def _build_client_args( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the TGI client.""" @@ -226,9 +247,10 @@ def __init__(self, client): async def generate( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> str: + ) -> Output: """Generate text using TGI. Parameters @@ -239,27 +261,31 @@ async def generate( The desired format of the response generated by the model. All output types except `CFG` are supported provided your server uses a backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - str + Output The text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) response = await self.client.text_generation(**client_args) - return response + return Output(content=response) async def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError("TGI does not support batch inference.") @@ -267,9 +293,10 @@ async def generate_batch( async def generate_stream( # type: ignore self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[StreamingOutput]: """Stream text using TGI. Parameters @@ -280,15 +307,18 @@ async def generate_stream( # type: ignore The desired format of the response generated by the model. All output types except `CFG` are supported provided your server uses a backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - AsyncIterator[str] + AsyncIterator[StreamingOutput] An async iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -298,12 +328,12 @@ async def generate_stream( # type: ignore ) async for chunk in stream: # pragma: no cover - yield chunk + yield StreamingOutput(content=chunk) def _build_client_args( self, model_input: str, - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the TGI client.""" diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 964dbce60..af0cea030 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -9,7 +9,9 @@ from outlines.inputs import Audio, Chat, Image, Video from outlines.models.base import Model, ModelTypeAdapter from outlines.models.tokenizer import Tokenizer +from outlines.outputs import Output from outlines.processors import OutlinesLogitsProcessor +from outlines.tools import ToolDef if TYPE_CHECKING: import torch @@ -173,7 +175,7 @@ def format_chat_input(self, model_input: Chat) -> str: def format_output_type( self, - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], ) -> Optional["LogitsProcessorList"]: """Generate the logits processor argument to pass to the model. @@ -194,6 +196,13 @@ def format_output_type( return LogitsProcessorList([output_type]) return None + def format_tools(self, tools): + """Not available for Transformers.""" + if tools: + raise NotImplementedError( + "Transformers does not support tools." + ) + class Transformers(Model): """Thin wrapper around a `transformers` model and a `transformers` @@ -295,9 +304,10 @@ def _prepare_model_inputs( def generate( self, model_input: Union[str, dict, Chat], - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, List[str]]: + ) -> Output | List[Output]: """Generate text using `transformers`. Parameters @@ -310,16 +320,19 @@ def generate( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the `generate` method of the `transformers` model. Returns ------- - Union[str, List[str]] + Output | List[Output] The text generated by the model. """ + self.type_adapter.format_tools(tools) prompts, inputs = self._prepare_model_inputs(model_input, False) logits_processor = self.type_adapter.format_output_type(output_type) @@ -336,15 +349,39 @@ def generate( if num_samples == 1 and len(generated_ids.shape) == 2: generated_ids = generated_ids.squeeze(0) - return self._decode_generation(generated_ids) + generated_text = self._decode_generation(generated_ids) + + if isinstance(generated_text, list): + return [Output(content=text) for text in generated_text] + return Output(content=generated_text) def generate_batch( self, model_input: List[Union[str, dict, Chat]], - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> List[Union[str, List[str]]]: - """""" + ) -> List[Output] | List[List[Output]]: + """Generate a batch of completions using `transformers`. + + Parameters + ---------- + model_input + The list of prompts based on which the model will generate a response. + output_type + The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. + **inference_kwargs + Additional keyword arguments to pass to the `generate` method of the `transformers` model. + + Returns + ------- + List[Output] | List[List[Output]] + The list of text generated by the model. + + """ + self.type_adapter.format_tools(tools) prompts, inputs = self._prepare_model_inputs(model_input, True) # type: ignore logits_processor = self.type_adapter.format_output_type(output_type) @@ -357,7 +394,17 @@ def generate_batch( if num_samples > 1: generated_ids = generated_ids.view(len(model_input), num_samples, -1) - return self._decode_generation(generated_ids) + generated_text = self._decode_generation(generated_ids) + + return [ # type: ignore + [ + Output(content=text) + for text in batch + ] + if isinstance(batch, list) + else Output(content=batch) + for batch in generated_text + ] def generate_stream(self, model_input, output_type, **inference_kwargs): """Not available for `transformers` models. @@ -369,7 +416,7 @@ def generate_stream(self, model_input, output_type, **inference_kwargs): "Streaming is not implemented for Transformers models." ) - def _generate_output_seq(self, prompts, inputs, **inference_kwargs): + def _generate_output_seq(self, prompts, inputs, **inference_kwargs): # type: ignore input_ids = inputs["input_ids"] output_ids = self.model.generate( @@ -472,7 +519,7 @@ def format_chat_input(self, model_input: Chat) -> dict: "content": message["content"][0], }) else: - messages_without_images.append(message) + messages_without_images.append(message) # type: ignore formatted_prompt = self.tokenizer.apply_chat_template( messages_without_images, tokenize=False @@ -513,7 +560,7 @@ def format_list_input(self, model_input: list) -> dict: def format_output_type( self, - output_type: Optional[OutlinesLogitsProcessor] = None, + output_type: Optional[OutlinesLogitsProcessor], ) -> Optional["LogitsProcessorList"]: """Generate the logits processor argument to pass to the model. @@ -534,6 +581,13 @@ def format_output_type( return LogitsProcessorList([output_type]) return None + def format_tools(self, tools): + """Not available for TransformersMultiModal.""" + if tools: + raise NotImplementedError( + "TransformersMultiModal does not support tools." + ) + class TransformersMultiModal(Transformers): """Thin wrapper around a `transformers` model and a `transformers` diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py index 1284ab336..4b975e9df 100644 --- a/outlines/models/vllm.py +++ b/outlines/models/vllm.py @@ -1,11 +1,21 @@ """Integration with a vLLM server.""" import json -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Iterator, + List, + Optional, + Union, +) from outlines.inputs import Chat from outlines.models.base import AsyncModel,Model, ModelTypeAdapter from outlines.models.openai import OpenAITypeAdapter +from outlines.outputs import Output, StreamingOutput +from outlines.tools import ToolDef from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex if TYPE_CHECKING: @@ -36,7 +46,7 @@ def format_input(self, model_input: Union[Chat, str, list]) -> list: """ return OpenAITypeAdapter().format_input(model_input) - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the structured output argument to pass to the client. Parameters @@ -64,6 +74,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict: else: return {"guided_regex": to_regex(term)} + def format_tools(self, tools): + """Not available for VLLM.""" + if tools: + raise NotImplementedError( + "Tools are not available for VLLM." + ) + class VLLM(Model): """Thin wrapper around the `openai.OpenAI` client used to communicate with @@ -93,9 +110,10 @@ def __init__( def generate( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Union[Output, list[Output]]: """Generate text using vLLM. Parameters @@ -106,15 +124,18 @@ def generate( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] + Union[Output, list[Output]] The text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, @@ -132,14 +153,15 @@ def generate( ) if len(messages) == 1: - return messages[0].content + return Output(content=messages[0].content) else: - return [message.content for message in messages] + return [Output(content=message.content) for message in messages] def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError("VLLM does not support batch inference.") @@ -147,9 +169,10 @@ def generate_batch( def generate_stream( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Iterator[str]: + ) -> Iterator[StreamingOutput]: """Stream text using vLLM. Parameters @@ -160,15 +183,18 @@ def generate_stream( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Iterator[str] + Iterator[StreamingOutput] An iterator that yields the text generated by the model. """ + self.type_adapter.format_tools(tools) client_args = self._build_client_args( model_input, output_type, **inference_kwargs, ) @@ -179,12 +205,12 @@ def generate_stream( for chunk in stream: # pragma: no cover if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + yield StreamingOutput(content=chunk.choices[0].delta.content) def _build_client_args( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the OpenAI client.""" @@ -234,9 +260,10 @@ def __init__( async def generate( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, list[str]]: + ) -> Union[Output, list[Output]]: """Generate text using vLLM. Parameters @@ -247,12 +274,14 @@ async def generate( The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - Union[str, list[str]] + Union[Output, list[Output]] The text generated by the model. """ @@ -271,14 +300,15 @@ async def generate( ) if len(messages) == 1: - return messages[0].content + return Output(content=messages[0].content) else: - return [message.content for message in messages] + return [Output(content=message.content) for message in messages] async def generate_batch( self, model_input, - output_type = None, + output_type, + tools, **inference_kwargs, ): raise NotImplementedError("VLLM does not support batch inference.") @@ -286,9 +316,10 @@ async def generate_batch( async def generate_stream( # type: ignore self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> AsyncIterator[str]: + ) -> AsyncIterator[StreamingOutput]: """Stream text using vLLM. Parameters @@ -299,13 +330,16 @@ async def generate_stream( # type: ignore The desired format of the response generated by the model. All output types available in Outlines are supported provided your server uses a structured generation backend that supports them. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the client. Returns ------- - AsyncIterator[str] + AsyncIterator[StreamingOutput] An async iterator that yields the text generated by the model. + """ client_args = self._build_client_args( model_input, output_type, **inference_kwargs, @@ -318,12 +352,12 @@ async def generate_stream( # type: ignore async for chunk in stream: # pragma: no cover if chunk.choices and chunk.choices[0].delta.content is not None: - yield chunk.choices[0].delta.content + yield StreamingOutput(content=chunk.choices[0].delta.content) def _build_client_args( self, model_input: Union[Chat, str, list], - output_type: Optional[Any] = None, + output_type: Optional[Any], **inference_kwargs: Any, ) -> dict: """Build the arguments to pass to the OpenAI client.""" diff --git a/outlines/models/vllm_offline.py b/outlines/models/vllm_offline.py index 5e38e72f4..cad463075 100644 --- a/outlines/models/vllm_offline.py +++ b/outlines/models/vllm_offline.py @@ -7,6 +7,8 @@ from outlines.inputs import Chat from outlines.models.base import Model, ModelTypeAdapter from outlines.models.openai import OpenAITypeAdapter +from outlines.outputs import Output +from outlines.tools import ToolDef from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex if TYPE_CHECKING: @@ -56,7 +58,7 @@ def format_input_chat(self, model_input: Chat) -> list: ) return OpenAITypeAdapter().format_input(model_input) - def format_output_type(self, output_type: Optional[Any] = None) -> dict: + def format_output_type(self, output_type: Optional[Any]) -> dict: """Generate the structured output argument to pass to the model. For vLLM, the structured output definition is set in the @@ -90,6 +92,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict: else: return {"regex": to_regex(term)} + def format_tools(self, tools): + """Not available for VLLM offline.""" + if tools: + raise NotImplementedError( + "Tools are not available for VLLM offline." + ) + class VLLMOffline(Model): """Thin wrapper around a `vllm.LLM` model. @@ -114,7 +123,7 @@ def __init__(self, model: "LLM"): def _build_generation_args( self, inference_kwargs: dict, - output_type: Optional[Any] = None, + output_type: Optional[Any], ) -> "SamplingParams": """Create the `SamplingParams` object to pass to the `generate` method of the `vllm.LLM` model.""" @@ -134,9 +143,10 @@ def _build_generation_args( def generate( self, model_input: Chat | str, - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[str, List[str]]: + ) -> Union[Output, List[Output]]: """Generate text using vLLM offline. Parameters @@ -146,16 +156,19 @@ def generate( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the `generate` method in the `vllm.LLM` model. Returns ------- - Union[str, List[str]] + Union[Output, List[Output]] The text generated by the model. """ + self.type_adapter.format_tools(tools) sampling_params = self._build_generation_args( inference_kwargs, output_type, @@ -168,7 +181,7 @@ def generate( **inference_kwargs, ) else: - results = self.model.generate( + results = self.model( prompts=self.type_adapter.format_input(model_input), sampling_params=sampling_params, **inference_kwargs, @@ -176,16 +189,17 @@ def generate( results = [completion.text for completion in results[0].outputs] if len(results) == 1: - return results[0] + return Output(content=results[0]) else: - return results + return [Output(content=result) for result in results] def generate_batch( self, model_input: List[Chat | str], - output_type: Optional[Any] = None, + output_type: Optional[Any], + tools: Optional[List[ToolDef]], **inference_kwargs: Any, - ) -> Union[List[str], List[List[str]]]: + ) -> Union[List[Output], List[List[Output]]]: """Generate a batch of completions using vLLM offline. Parameters @@ -196,16 +210,19 @@ def generate_batch( output_type The logits processor the model will use to constrain the format of the generated text. + tools + The tools to use for the generation. inference_kwargs Additional keyword arguments to pass to the `generate` method in the `vllm.LLM` model. Returns ------- - Union[List[str], List[List[str]]] + Union[List[Output], List[List[Output]]] The text generated by the model. """ + self.type_adapter.format_tools(tools) sampling_params = self._build_generation_args( inference_kwargs, output_type, @@ -216,14 +233,20 @@ def generate_batch( "Batch generation is not available for the `Chat` input type." ) - results = self.model.generate( + results = self.model( prompts=[self.type_adapter.format_input(item) for item in model_input], sampling_params=sampling_params, **inference_kwargs, ) - return [[sample.text for sample in batch.outputs] for batch in results] - def generate_stream(self, model_input, output_type, **inference_kwargs): + return [ # type: ignore + [Output(content=sample.text) for sample in batch.outputs] + if len(batch.outputs) > 1 + else Output(content=batch.outputs[0].text) + for batch in results + ] + + def generate_stream(self, model_input, output_type, tools, **inference_kwargs): """Not available for `vllm.LLM`. TODO: Implement the streaming functionality ourselves. diff --git a/outlines/outputs.py b/outlines/outputs.py new file mode 100644 index 000000000..31862eb1e --- /dev/null +++ b/outlines/outputs.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Any, List, Optional + + +@dataclass +class ToolCallOutput: + """Contains the output of a tool call.""" + name: str + args: dict[str, Any] + id: Optional[str] = None + + +@dataclass +class StreamingToolCallOutput: + """Contains the output of a streaming tool call.""" + name: str + args: str + id: Optional[str] = None + + +@dataclass +class Output: + """Contains the output of a response from a model.""" + content: Optional[str] = None + tool_calls: Optional[List[ToolCallOutput]] = None + + def __str__(self) -> str: + """Return the content as a string.""" + return self.content or "" + + def __repr__(self) -> str: + """Return a string representation of the Output object.""" + return f"Output(content={self.content}, tool_calls={self.tool_calls})" + + def __add__(self, other) -> str: + """Support string concatenation with the content.""" + return str(self) + str(other) + + def __radd__(self, other) -> str: + """Support string concatenation with the content on the right side.""" + return str(other) + str(self) + + +@dataclass +class StreamingOutput: + """Contains the output of a streaming response from a model.""" + content: Optional[str] = None + tool_calls: Optional[List[StreamingToolCallOutput]] = None + + def __str__(self) -> str: + """Return the content as a string.""" + return self.content or "" + + def __repr__(self) -> str: + """Return a string representation of the Output object.""" + return f"StreamingOutput(content={self.content}, tool_calls={self.tool_calls})" + + def __add__(self, other) -> str: + """Support string concatenation with the content.""" + return str(self) + str(other) + + def __radd__(self, other) -> str: + """Support string concatenation with the content on the right side.""" + return str(other) + str(self) diff --git a/outlines/tools.py b/outlines/tools.py new file mode 100644 index 000000000..8b5f64e55 --- /dev/null +++ b/outlines/tools.py @@ -0,0 +1,203 @@ +import sys +import inspect +from typing import Any, Dict, List, Optional, Callable, Union, cast + +from pydantic import BaseModel + +from outlines.types.dsl import is_callable, is_dict_instance, is_pydantic_model + +if sys.version_info >= (3, 12): # pragma: no cover + from typing import TypedDict +else: # pragma: no cover + from typing_extensions import TypedDict + + +class ToolDef(TypedDict): + name: str + description: str + parameters: Dict[str, Dict[str, str]] + required: list[str] + + +ToolsInput = List[ToolDef | Callable | BaseModel] + + +def get_formatted_tools( + tools: Optional[ToolsInput] = None +) -> Optional[List[ToolDef]]: + """Convert a ToolsInput into a list of ToolDef instances. + + Parameters + ---------- + tools : Optional[ToolsInput] + List of tools to format. Can contain a list of ToolDef, Callable, or + BaseModel instances. + + Returns + ------- + Optional[List[ToolDef]] + List of ToolDef instances. If no tools are provided, returns `None`. + + """ + if not tools: + return None + + formatted_tools: List[ToolDef] = [] + + for tool in tools: + if is_dict_instance(tool): + tool_dict = cast(Dict[str, Any], tool) + if all( + key in tool_dict + for key in ["name", "description", "parameters", "required"] + ): + formatted_tools.append(cast(ToolDef, tool_dict)) + else: + missing_keys = ( + set(tool_dict.keys()) + - set(["name", "description", "parameters", "required"]) + ) + raise ValueError( + f"Invalid ToolDef: {tool}. " + + "Expected a dictionary with keys 'name', 'description', " + + "'parameters', and 'required'. " + + f"Missing keys: {missing_keys}" + ) + + elif is_callable(tool): + callable_tool = cast(Callable[..., Any], tool) + formatted_tools.append(_callable_to_tool_def(callable_tool)) + + elif is_pydantic_model(tool): + model_tool = cast(type[BaseModel], tool) + formatted_tools.append(_pydantic_model_to_tool_def(model_tool)) + + else: + raise ValueError( + f"Unsupported tool type: {type(tool)}. " + + "Expected ToolDef, callable, or Pydantic model." + ) + + return formatted_tools + + +def _callable_to_tool_def(func: Callable) -> ToolDef: + """Convert a callable to a ToolDef instance. + + Parameters + ---------- + func : Callable + The function to convert. + + Returns + ------- + ToolDef + ToolDef instance. + + """ + signature = inspect.signature(func) + name = func.__name__ + description = func.__doc__ or f"Function {name}" + parameters = {} + required = [] + + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty: + raise ValueError( + f"Parameter {param_name} has no annotation. " + + "All parameters must have an annotation." + ) + + param_type = _type_to_string(param.annotation) + parameters[param_name] = { + "type": param_type + } + + if param.default == inspect.Parameter.empty: + required.append(param_name) + + return { + "name": name, + "description": description, + "parameters": parameters, + "required": required + } + + +def _pydantic_model_to_tool_def(model_class: type[BaseModel]) -> ToolDef: + """Convert a Pydantic model to a ToolDef instance. + + Parameters + ---------- + model_class : type[BaseModel] + The Pydantic model class to convert. + + Returns + ------- + ToolDef + ToolDef instance with extracted model information. + + """ + schema = model_class.model_json_schema() + name = schema.get("title", model_class.__name__) + description = schema.get("description", f"Model {name}") + properties = schema.get("properties", {}) + required = schema.get("required", []) + + parameters = {} + for prop_name, prop_schema in properties.items(): + prop_type = prop_schema.get("type", "string") + parameters[prop_name] = { + "type": prop_type + } + + return { + "name": name, + "description": description, + "parameters": parameters, + "required": required + } + + +def _type_to_string(type_annotation: Any) -> str: + """Convert a Python type annotation to a string representation. + + Parameters + ---------- + type_annotation : Any + The type annotation to convert. + + Returns + ------- + str + String representation of the type. + + """ + # Handle Union types (including Optional) + if ( + hasattr(type_annotation, '__origin__') and + type_annotation.__origin__ is Union + ): + args = type_annotation.__args__ + union_types = [] + for arg in args: + if arg is type(None): + union_types.append("null") + else: + union_types.append(_type_to_string(arg)) + return "|".join(union_types) + + if type_annotation is str: + return "string" + elif type_annotation is int: + return "integer" + elif type_annotation is float: + return "number" + elif type_annotation is bool: + return "boolean" + elif type_annotation is list: + return "array" + elif type_annotation is dict: + return "object" + else: + raise ValueError(f"Unsupported type: {type_annotation}") diff --git a/tests/backends/test_llguidance.py b/tests/backends/test_llguidance.py index fb5faea91..bf06be6e3 100644 --- a/tests/backends/test_llguidance.py +++ b/tests/backends/test_llguidance.py @@ -158,15 +158,15 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ assert isinstance(processor, LLGuidanceLogitsProcessor) generator = outlines.Generator(model, backend="llguidance", processor=processor) response = generator("Hello, how are you?") - assert response[0] == "{" + assert response.content[0] == "{" # regex processor = backend.get_regex_logits_processor(regex) assert isinstance(processor, LLGuidanceLogitsProcessor) generator = outlines.Generator(model, backend="llguidance", processor=processor) response = generator("Hello, how are you?") - assert len(response) == 3 - assert int(response) + assert len(response.content) == 3 + assert int(response.content) # cfg lark processor = backend.get_cfg_logits_processor(cfg_lark) @@ -174,11 +174,11 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ generator = outlines.Generator(model, backend="llguidance", processor=processor) response = generator("Hello, how are you?") assert ( - "+" in response - or "-" in response - or "*" in response - or "/" in response - or float(response.strip()) + "+" in response.content + or "-" in response.content + or "*" in response.content + or "/" in response.content + or float(response.content.strip()) ) # cfg ebnf @@ -186,7 +186,7 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ assert isinstance(processor, LLGuidanceLogitsProcessor) generator = outlines.Generator(model, backend="llguidance", processor=processor) response = generator("Hello, how are you?") - assert response == "yes" or response == "no" + assert response.content == "yes" or response.content == "no" # batch + multiple generations processor = backend.get_json_schema_logits_processor(json_schema) @@ -196,7 +196,7 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_ response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200) assert len(response) == 2 for r in response: - assert r[0] == "{" + assert r.content[0] == "{" else: response = generator("Create a character", max_tokens=20) - assert response[0] == "{" + assert response.content[0] == "{" diff --git a/tests/backends/test_outlines_core.py b/tests/backends/test_outlines_core.py index ce3242457..0ec04c153 100644 --- a/tests/backends/test_outlines_core.py +++ b/tests/backends/test_outlines_core.py @@ -149,15 +149,15 @@ def test_outlines_core_backend(model, tensor_library_name, json_schema, regex, c assert isinstance(processor, OutlinesCoreLogitsProcessor) generator = outlines.Generator(model, backend="outlines_core", processor=processor) response = generator("Hello, how are you?") - assert "name" in response + assert "name" in response.content # regex processor = backend.get_regex_logits_processor(regex) assert isinstance(processor, OutlinesCoreLogitsProcessor) generator = outlines.Generator(model, backend="outlines_core", processor=processor) response = generator("Hello, how are you?") - assert len(response) == 3 - assert int(response) + assert len(response.content) == 3 + assert int(response.content) # cfg with pytest.raises( @@ -174,9 +174,9 @@ def test_outlines_core_backend(model, tensor_library_name, json_schema, regex, c response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200) assert len(response) == 2 for r in response: - assert r[0] == "{" - assert "name" in r + assert r.content[0] == "{" + assert "name" in r.content else: response = generator("Create a character", max_tokens=20) - assert response[0] == "{" - assert "name" in response + assert response.content[0] == "{" + assert "name" in response.content diff --git a/tests/backends/test_xgrammar.py b/tests/backends/test_xgrammar.py index e25f66508..14dc837c3 100644 --- a/tests/backends/test_xgrammar.py +++ b/tests/backends/test_xgrammar.py @@ -125,23 +125,23 @@ def test_xgrammar_backend(model, tensor_library_name, json_schema, regex, cfg): assert isinstance(processor, XGrammarLogitsProcessor) generator = outlines.Generator(model, backend="xgrammar", processor=processor) response = generator("Hello, how are you?") - assert response[0] == "{" - assert "name" in response + assert response.content[0] == "{" + assert "name" in response.content # regex processor = backend.get_regex_logits_processor(regex) assert isinstance(processor, XGrammarLogitsProcessor) generator = outlines.Generator(model, backend="xgrammar", processor=processor) response = generator("Hello, how are you?") - assert len(response) == 3 - assert int(response) + assert len(response.content) == 3 + assert int(response.content) # cfg processor = backend.get_cfg_logits_processor(cfg) assert isinstance(processor, XGrammarLogitsProcessor) generator = outlines.Generator(model, backend="xgrammar", processor=processor) response = generator("Hello, how are you?") - assert response == "yes" or response == "no" + assert response.content == "yes" or response.content == "no" # batch + multiple generations processor = backend.get_json_schema_logits_processor(json_schema) @@ -151,12 +151,12 @@ def test_xgrammar_backend(model, tensor_library_name, json_schema, regex, cfg): response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200) assert len(response) == 2 for r in response: - assert r[0] == "{" - assert "name" in r + assert r.content[0] == "{" + assert "name" in r.content else: response = generator("Create a character", max_tokens=20) - assert response[0] == "{" - assert "name" in response + assert response.content[0] == "{" + assert "name" in response.content def test_xgrammar_backend_invalid_model(): diff --git a/tests/models/test_anthopic_type_adapter.py b/tests/models/test_anthopic_type_adapter.py deleted file mode 100644 index 0be6e0ebf..000000000 --- a/tests/models/test_anthopic_type_adapter.py +++ /dev/null @@ -1,118 +0,0 @@ -import io -import pytest -from dataclasses import dataclass - -from PIL import Image as PILImage -from outlines.inputs import Chat, Image -from outlines.models.anthropic import AnthropicTypeAdapter - - -@pytest.fixture -def image(): - width, height = 1, 1 - white_background = (255, 255, 255) - image = PILImage.new("RGB", (width, height), white_background) - - # Save to an in-memory bytes buffer and read as png - buffer = io.BytesIO() - image.save(buffer, format="PNG") - buffer.seek(0) - image = PILImage.open(buffer) - - return image - - -@pytest.fixture -def adapter(): - return AnthropicTypeAdapter() - - -def test_anthropic_type_adapter_input_text(adapter): - message = "prompt" - result = adapter.format_input(message) - assert result == {"messages": [{"role": "user", "content": message}]} - - -def test_anthropic_type_adapter_input_vision(adapter, image): - image_input = Image(image) - text_input = "hello" - result = adapter.format_input([text_input, image_input]) - assert result == { - "messages": [ - { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": image_input.image_str, - }, - }, - {"type": "text", "text": text_input}, - ], - }, - ] - } - - -def test_anthropic_type_adapter_input_chat(adapter, image): - image_input = Image(image) - model_input = Chat(messages=[ - {"role": "system", "content": "prompt"}, - {"role": "user", "content": [ - "hello", - image_input, - ]}, - {"role": "assistant", "content": "response"}, - ]) - result = adapter.format_input(model_input) - assert result == { - "messages": [ - {"role": "system", "content": "prompt"}, - {"role": "user", "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": image_input.image_str, - }, - }, - {"type": "text", "text": "hello"}, - ]}, - {"role": "assistant", "content": "response"}, - ] - } - - -def test_anthropic_type_adapter_input_invalid(adapter): - @dataclass - class Audio: - file: str - - with pytest.raises(TypeError, match="is not available with Anthropic"): - _ = adapter.format_input(Audio("file")) - - with pytest.raises( - ValueError, - match="All assets provided must be of type Image", - ): - _ = adapter.format_input(["prompt", Audio("file")]) - - with pytest.raises( - ValueError, - match="The content must be a string or a list", - ): - _ = adapter.format_input( - Chat(messages=[{"role": "user", "content": {"foo": "bar"}}]) - ) - - -def test_anthropic_type_adapter_output(adapter): - with pytest.raises( - NotImplementedError, - match="is not available with Anthropic" - ): - adapter.format_output_type(str) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index fe9b976cc..127b807c3 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1,13 +1,20 @@ import io from typing import Generator +import pytest from anthropic import Anthropic as AnthropicClient from PIL import Image as PILImage -import pytest import outlines -from outlines.inputs import Chat, Image, Video +from outlines.inputs import Chat, Image from outlines.models.anthropic import Anthropic +from outlines.outputs import ( + Output, + StreamingOutput, + StreamingToolCallOutput, + ToolCallOutput, +) +from outlines.tools import ToolDef MODEL_NAME = "claude-3-haiku-20240307" @@ -38,6 +45,27 @@ def image(): return image +@pytest.fixture +def tools(): + return [ + ToolDef( + name="get_weather", + description="Get the current weather for a given city", + parameters={"city": {"type": "string"}}, + required=["city"], + ), + ToolDef( + name="get_user_info", + description="Get the current user info", + parameters={ + "first_name": {"type": "string"}, + "last_name": {"type": "string"} + }, + required=["last_name"], + ), + ] + + def test_init_from_client(): client = AnthropicClient() @@ -57,83 +85,124 @@ def test_init_from_client(): def test_anthropic_wrong_inference_parameters(): with pytest.raises(TypeError, match="got an unexpected"): model = Anthropic(AnthropicClient(), MODEL_NAME) - model.generate("prompt", foo=10, max_tokens=1024) - - -def test_anthropic_wrong_input_type(image): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(TypeError, match="is not available"): - model = Anthropic(AnthropicClient(), MODEL_NAME) - model.generate(Foo("prompt")) - - with pytest.raises(ValueError, match="All assets provided must be of type Image"): - model.generate(["foo?", Image(image), Video("")]) - - -def test_anthropic_wrong_output_type(): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(NotImplementedError, match="is not available"): - model = Anthropic(AnthropicClient(), MODEL_NAME) - model.generate("prompt", Foo(1)) + model("prompt", foo=10, max_tokens=1) @pytest.mark.api_call -def test_anthropic_simple_call(model): - result = model.generate("Respond with one word. Not more.", max_tokens=1024) - assert isinstance(result, str) +def test_anthropic_call(model): + result = model("Respond with one word. Not more.", max_tokens=100) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.xfail(reason="Anthropic requires the `max_tokens` parameter to be set") @pytest.mark.api_call -def test_anthropic_direct_call(model_no_model_name): +def test_anthropic_call_no_max_tokens(model_no_model_name): result = model_no_model_name( "Respond with one word. Not more.", model_name=MODEL_NAME, - max_tokens=1024, + max_tokens=100, ) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_anthropic_simple_vision(model, image): - result = model.generate( +def test_anthropic_vision(model, image): + result = model( [ "What does this logo represent?", Image(image), ], - max_tokens=1024, + max_tokens=100, ) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call def test_anthropic_chat(model, image): - result = model.generate(Chat(messages=[ + result = model(Chat(messages=[ {"role": "assistant", "content": "How can I help you today?"}, { "role": "user", "content": ["What does this logo represent?", Image(image)] }, - ]), max_tokens=10) - assert isinstance(result, str) + ]), max_tokens=100) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_anthopic_streaming(model): - result = model.stream("Respond with one word. Not more.", max_tokens=1024) +def test_anthropic_tools(model, tools): + result = model( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=100, + ) + assert isinstance(result, Output) + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert isinstance(tool_call, ToolCallOutput) + assert tool_call.name == "get_weather" + assert tool_call.args == {"city": "Tokyo"} + assert tool_call.id is not None + + +@pytest.mark.api_call +def test_anthropic_tools_chat(model, tools): + chat = Chat( + messages=[ + {"role": "user", "content": "What is the weather in Tokyo?"}, + ], + ) + generator = outlines.Generator(model, tools=tools) + result = generator(chat, max_tokens=100) + chat.add_output(result) + chat.add_tool_message( + tool_call_id=result.tool_calls[0].id, + content="The weather in Tokyo is sunny.", + ) + chat.add_user_message("Is it a good weather to go out?") + result = generator(chat, max_tokens=100) + assert isinstance(result, Output) + assert isinstance(result.content, str) + + +@pytest.mark.api_call +def test_anthropic_streaming(model): + result = model.stream("Respond with one sentence. Not more.", max_tokens=100) + assert isinstance(result, Generator) + for chunk in result: + assert isinstance(chunk, StreamingOutput) + assert isinstance(chunk.content, str) + + +@pytest.mark.api_call +def test_anthropic_streaming_tools(model, tools): + result = model.stream( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=100, + ) assert isinstance(result, Generator) - assert isinstance(next(result), str) + for chunk in result: + assert isinstance(chunk, StreamingOutput) + if chunk.tool_calls is not None: + assert len(chunk.tool_calls) == 1 + tool_call = chunk.tool_calls[0] + assert isinstance(tool_call, StreamingToolCallOutput) + assert tool_call.name == "get_weather" + assert isinstance(tool_call.args, str) + assert tool_call.id is not None + else: + assert chunk.content is not None def test_anthropic_batch(model): with pytest.raises(NotImplementedError, match="does not support"): model.batch( ["Respond with one word.", "Respond with one word."], - max_tokens=1024, + max_tokens=1, ) diff --git a/tests/models/test_anthropic_type_adapter.py b/tests/models/test_anthropic_type_adapter.py new file mode 100644 index 000000000..5270893d0 --- /dev/null +++ b/tests/models/test_anthropic_type_adapter.py @@ -0,0 +1,243 @@ +import io +import pytest +from dataclasses import dataclass + +from PIL import Image as PILImage + +from outlines.inputs import Chat, Image +from outlines.models.anthropic import AnthropicTypeAdapter +from outlines.tools import ToolDef + + +@pytest.fixture +def image(): + width, height = 1, 1 + white_background = (255, 255, 255) + image = PILImage.new("RGB", (width, height), white_background) + + # Save to an in-memory bytes buffer and read as png + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + image = PILImage.open(buffer) + + return image + + +@pytest.fixture +def adapter(): + return AnthropicTypeAdapter() + + +def test_anthropic_type_adapter_input_text(adapter): + message = "prompt" + result = adapter.format_input(message) + assert result == {"messages": [ + {"role": "user", "content": [ + {"type": "text", "text": message} + ]} + ]} + + +def test_anthropic_type_adapter_input_vision(adapter, image): + image_input = Image(image) + text_input = "hello" + result = adapter.format_input([text_input, image_input]) + assert result == { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": text_input, + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_input.image_str, + }, + }, + ], + }, + ] + } + + +def test_anthropic_type_adapter_input_chat(adapter, image): + image_input = Image(image) + model_input = Chat(messages=[ + { + "role": "user", + "content": [ + "hello", + image_input, + ] + }, + { + "role": "assistant", + "tool_calls": [ + { + "tool_name": "tool_name", + "tool_call_id": "abc", + "args": {"foo": "bar"} + } + ] + }, + { + "role": "tool", + "content": "response", + "tool_call_id": "abc" + }, + {"role": "user", "content": "prompt"}, + ]) + result = adapter.format_input(model_input) + assert result == { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_input.image_str, + }, + }, + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "abc", + "name": "tool_name", + "input": {"foo": "bar"}, + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "abc", + "content": "response", + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "prompt", + } + ] + }, + ] + } + + +def test_anthropic_type_adapter_input_invalid(adapter): + @dataclass + class Audio: + file: str + + # Invalid input type + with pytest.raises(TypeError, match="is not available"): + _ = adapter.format_input(image) + + # Invalid type within list input + with pytest.raises( + ValueError, + match="All assets provided must be of type Image", + ): + _ = adapter.format_input(["prompt", Audio("file")]) + + # Chat message with system role + with pytest.raises(ValueError, match="System messages are not supported in Chat inputs for Anthropic"): + _ = adapter.format_input(Chat(messages=[{"role": "system", "content": "prompt"}])) + + # Chat message with invalid role + with pytest.raises(ValueError, match="Invalid message role"): + _ = adapter.format_input(Chat(messages=[{"content": "prompt"}])) + + # Chat message with invalid content type + with pytest.raises(ValueError, match="Invalid content type"): + _ = adapter.format_input(Chat(messages=[{"role": "user", "content": {"foo": "bar"}}])) + + # Chat message with user role and no content + with pytest.raises(ValueError, match="Content is required for user messages"): + _ = adapter.format_input(Chat(messages=[{"role": "user"}])) + + # Chat message with assistant role and neither content nor tool calls + with pytest.raises(ValueError, match="Either content or tool calls is required for assistant messages"): + _ = adapter.format_input(Chat(messages=[{"role": "assistant"}])) + + # Chat message with tool role and no content + with pytest.raises(ValueError, match="Content and tool call id are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "tool_call_id": "abc"}])) + + # Chat message with tool role and no tool call id + with pytest.raises(ValueError, match="Content and tool call id are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "content": "response"}])) + + +def test_anthropic_type_adapter_output(adapter): + with pytest.raises( + NotImplementedError, + match="is not available with Anthropic" + ): + adapter.format_output_type(str) + + +def test_anthropic_type_adapter_format_tools(adapter): + tools = [ + ToolDef( + name="tool_name", + description="tool_description", + parameters={"foo": {"type": "string"}}, + required=["foo"], + ), + ToolDef( + name="tool_name_2", + description="tool_description_2", + parameters={ + "foo": {"type": "string"}, + "bar": {"type": "integer"} + }, + required=["bar"], + ), + ] + result = adapter.format_tools(tools) + assert result == [ + { + "name": "tool_name", + "description": "tool_description", + "input_schema": { + "type": "object", + "properties": { + "foo": {"type": "string"} + }, + "required": ["foo"], + }, + }, + { + "name": "tool_name_2", + "description": "tool_description_2", + "input_schema": { + "type": "object", + "properties": { + "foo": {"type": "string"}, + "bar": {"type": "integer"} + }, + "required": ["bar"], + }, + }, + ] diff --git a/tests/models/test_dottxt.py b/tests/models/test_dottxt.py index df637cfe7..17a268144 100644 --- a/tests/models/test_dottxt.py +++ b/tests/models/test_dottxt.py @@ -8,6 +8,7 @@ import outlines from outlines import Generator from outlines.models.dottxt import Dottxt +from outlines.outputs import Output MODEL_NAME = "dottxt/dottxt-v1-alpha" @@ -99,7 +100,8 @@ def test_dottxt_wrong_inference_parameters(model_no_model_name): @pytest.mark.api_call def test_dottxt_direct_pydantic_call(model_no_model_name): result = model_no_model_name("Create a user", User) - assert "first_name" in json.loads(result) + assert isinstance(result, Output) + assert "first_name" in json.loads(result.content) @pytest.mark.api_call @@ -112,14 +114,16 @@ def test_dottxt_direct_jsonschema_call( model_name=model_name_and_revision[0], model_revision=model_name_and_revision[1], ) - assert "first_name" in json.loads(result) + assert isinstance(result, Output) + assert "first_name" in json.loads(result.content) @pytest.mark.api_call def test_dottxt_generator_pydantic_call(model): generator = Generator(model, User) result = generator("Create a user") - assert "first_name" in json.loads(result) + assert isinstance(result, Output) + assert "first_name" in json.loads(result.content) @pytest.mark.api_call diff --git a/tests/models/test_dottxt_type_adapter.py b/tests/models/test_dottxt_type_adapter.py index a0356502c..8e7b8ba04 100644 --- a/tests/models/test_dottxt_type_adapter.py +++ b/tests/models/test_dottxt_type_adapter.py @@ -10,6 +10,7 @@ from outlines.inputs import Image from outlines.models.dottxt import DottxtTypeAdapter +from outlines.tools import ToolDef from outlines.types import cfg, json_schema, regex if sys.version_info >= (3, 12): @@ -58,7 +59,7 @@ def test_dottxt_type_adapter_input_text(adapter): def test_dottxt_type_adapter_input_invalid(adapter, image): - prompt = ["prompt", image] + prompt = ["prompt", Image(image)] with pytest.raises(TypeError, match="The input type"): _ = adapter.format_input(prompt) @@ -135,3 +136,15 @@ def test_dottxt_type_adapter_json_schema_str(adapter, schema): def test_dottxt_type_adapter_json_schema_dict(adapter, schema): result = adapter.format_output_type(json_schema(schema)) assert result == json.dumps(schema) + + +def test_dottxt_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="Dottxt does not support tools." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 3582afa07..58493e37f 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -13,7 +13,9 @@ import outlines from outlines.inputs import Chat, Image, Video from outlines.models.gemini import Gemini +from outlines.outputs import Output, StreamingOutput from outlines.types import Choice +from outlines.tools import ToolDef if sys.version_info >= (3, 12): from typing import TypedDict @@ -48,6 +50,27 @@ def image(): return image +@pytest.fixture +def tools(): + return [ + ToolDef( + name="get_weather", + description="Get the current weather for a given city", + parameters={"city": {"type": "string"}}, + required=["city"], + ), + ToolDef( + name="get_user_info", + description="Get the current user info", + parameters={ + "first_name": {"type": "string"}, + "last_name": {"type": "string"} + }, + required=["last_name"], + ), + ] + + @pytest.mark.api_call def test_gemini_init_from_client(): client = Client() @@ -68,19 +91,20 @@ def test_gemini_init_from_client(): @pytest.mark.api_call def test_gemini_wrong_inference_parameters(model): with pytest.raises(ValidationError): - model.generate("prompt", foo=10) + model("prompt", foo=10) @pytest.mark.api_call def test_gemini_wrong_input_type(model, image): with pytest.raises(ValueError, match="All assets provided must be of type Image"): - model.generate(["foo?", Image(image), Video("")]) + model(["foo?", Image(image), Video("")]) @pytest.mark.api_call -def test_gemini_simple_call(model): - result = model.generate("Respond with one word. Not more.") - assert isinstance(result, str) +def test_gemini_call(model): + result = model("Respond with one word. Not more.") + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call @@ -89,45 +113,50 @@ def test_gemini_direct_call(model_no_model_name): "Respond with one word. Not more.", model=MODEL_NAME ) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_gemini_simple_vision(model, image): - result = model.generate(["What does this logo represent?", Image(image)]) - assert isinstance(result, str) +def test_gemini_vision(model, image): + result = model(["What does this logo represent?", Image(image)]) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call def test_gemini_chat(model, image): - result = model.generate(Chat(messages=[ + result = model(Chat(messages=[ {"role": "assistant", "content": "How can I help you today?"}, { "role": "user", "content": ["What does this logo represent?", Image(image)] }, ])) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_gemini_simple_pydantic(model): +def test_gemini_pydantic(model): class Foo(BaseModel): bar: int - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.api_call -def test_gemini_simple_vision_pydantic(model, image): +def test_gemini_vision_pydantic(model, image): class Logo(BaseModel): name: int - result = model.generate(["What does this logo represent?", Image(image)], Logo) - assert isinstance(result, str) - assert "name" in json.loads(result) + result = model(["What does this logo represent?", Image(image)], Logo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "name" in json.loads(result.content) @pytest.mark.api_call @@ -139,112 +168,177 @@ class Foo(BaseModel): sna: int bar: Bar - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert "sna" in json.loads(result) - assert "bar" in json.loads(result) - assert "fu" in json.loads(result)["bar"] + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "sna" in json.loads(result.content) + assert "bar" in json.loads(result.content) + assert "fu" in json.loads(result.content)["bar"] @pytest.mark.xfail( reason="The Gemini SDK's serialization method does not support Json Schema strings." ) @pytest.mark.api_call -def test_gemini_simple_json_schema_string(model): +def test_gemini_json_schema_string(model): schema = "{'properties': {'bar': {'title': 'Bar', 'type': 'integer'}}, 'required': ['bar'], 'title': 'Foo', 'type': 'object'}" - result = model.generate("foo?", schema) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", schema) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.xfail( reason="The Gemini SDK's serialization method does not support Json Schema dictionaries." ) @pytest.mark.api_call -def test_gemini_simple_json_schema_dict(model): +def test_gemini_json_schema_dict(model): schema = { "properties": {"bar": {"type": "integer"}}, "required": ["bar"], "type": "object", } - result = model.generate("foo?", schema) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", schema) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.api_call -def test_gemini_simple_typed_dict(model): +def test_gemini_typed_dict(model): class Foo(TypedDict): bar: int - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.api_call -def test_gemini_simple_dataclass(model): +def test_gemini_dataclass(model): @dataclass class Foo: bar: int - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.api_call -def test_gemini_simple_choice_enum(model): +def test_gemini_choice_enum(model): class Foo(Enum): bar = "Bar" foor = "Foo" - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert result == "Foo" or result == "Bar" + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert result.content == "Foo" or result.content == "Bar" @pytest.mark.api_call -def test_gemini_simple_choice_choice(model): - result = model.generate("foo?", Choice(["Foo", "Bar"])) - assert isinstance(result, str) - assert result == "Foo" or result == "Bar" +def test_gemini_choice_choice(model): + result = model("foo?", Choice(["Foo", "Bar"])) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert result.content == "Foo" or result.content == "Bar" @pytest.mark.api_call -def test_gemini_sample_choice_literal(model): - result = model.generate("foo?", Literal["Foo", "Bar"]) - assert isinstance(result, str) - assert result == "Foo" or result == "Bar" +def test_gemini_choice_literal(model): + result = model("foo?", Literal["Foo", "Bar"]) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert result.content == "Foo" or result.content == "Bar" @pytest.mark.xfail( reason="Gemini supports lists for choices but we do not as it is semantically incorrect." ) @pytest.mark.api_call -def test_gemini_simple_choice_list(model): +def test_gemini_choice_list(model): choices = ["Foo", "Bar"] - result = model.generate("foo?", choices) - assert isinstance(result, str) - assert result == "Foo" or result == "Bar" + result = model("foo?", choices) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert result.content == "Foo" or result.content == "Bar" @pytest.mark.api_call -def test_gemini_simple_list_pydantic(model): +def test_gemini_list_pydantic(model): class Foo(BaseModel): bar: int - result = model.generate("foo?", list[Foo]) - assert isinstance(json.loads(result), list) - assert isinstance(json.loads(result)[0], dict) - assert "bar" in json.loads(result)[0] + result = model("foo?", list[Foo]) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content)[0] + + +@pytest.mark.api_call +def test_gemini_tools(model, tools): + result = model( + "What is the weather in Tokyo?", + tools=tools, + ) + assert isinstance(result, Output) + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call.name == "get_weather" + assert tool_call.args == {"city": "Tokyo"} or {"city": "tokyo"} + assert tool_call.name is not None + + +@pytest.mark.api_call +def test_gemini_tools_chat(model, tools): + chat = Chat( + messages=[ + {"role": "user", "content": "What is the weather in Tokyo?"}, + ], + ) + generator = outlines.Generator(model, tools=tools) + result = generator(chat) + print("RESULT", result) + chat.add_output(result) + chat.add_tool_message( + tool_name=result.tool_calls[0].name, + content="The weather in Tokyo is sunny.", + ) + chat.add_user_message("Is it a good weather to go out?") + result = generator(chat) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call def test_gemini_streaming(model): result = model.stream("Respond with one word. Not more.") assert isinstance(result, Generator) - assert isinstance(next(result), str) + chunk = next(result) + assert isinstance(chunk, StreamingOutput) + assert isinstance(chunk.content, str) + + +@pytest.mark.api_call +def test_gemini_streaming_tools(model, tools): + result = model.stream( + "What is the weather in Tokyo?", + tools=tools, + ) + assert isinstance(result, Generator) + for chunk in result: + assert isinstance(chunk, StreamingOutput) + assert chunk.content is None + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + tool_call = chunk.tool_calls[0] + assert tool_call.name == "get_weather" + assert isinstance(tool_call.args, str) + assert tool_call.name is not None @pytest.mark.api_call diff --git a/tests/models/test_gemini_type_adapter.py b/tests/models/test_gemini_type_adapter.py index 0f1985991..4b9ecfb02 100644 --- a/tests/models/test_gemini_type_adapter.py +++ b/tests/models/test_gemini_type_adapter.py @@ -7,12 +7,12 @@ from PIL import Image as PILImage from genson import SchemaBuilder -from google.genai import types from pydantic import BaseModel from outlines import cfg, json_schema, regex from outlines.inputs import Chat, Image from outlines.models.gemini import GeminiTypeAdapter +from outlines.tools import ToolDef if sys.version_info >= (3, 12): from typing import TypedDict @@ -56,7 +56,7 @@ def adapter(): def test_gemini_type_adapter_input_text(adapter): message = "prompt" result = adapter.format_input(message) - assert result == {"contents": [{"text": message}]} + assert result == {"contents": [{"role": "user", "parts": [{"text": message}]}]} def test_gemini_type_adapter_input_vision(adapter, image): @@ -76,28 +76,46 @@ def test_gemini_type_adapter_input_vision(adapter, image): }, }, ], - }, + } ] } def test_gemini_type_adapter_input_chat(adapter, image): image_input = Image(image) - input_message = Chat(messages=[ - {"role": "assistant", "content": "How can I help you today?"}, - {"role": "user", "content": [ - "What does this logo represent?", - image_input, - ]}, + model_input = Chat(messages=[ + { + "role": "user", + "content": [ + "hello", + image_input, + ] + }, + { + "role": "assistant", + "tool_calls": [ + { + "tool_name": "tool_name", + "tool_call_id": "abc", + "args": {"foo": "bar"} + } + ] + }, + { + "role": "tool", + "content": "response", + "tool_name": "tool_name", + "tool_call_id": "abc" + }, + {"role": "user", "content": "prompt"}, ]) - result = adapter.format_input(input_message) + result = adapter.format_input(model_input) assert result == { "contents": [ - {"role": "model", "parts": [{"text": "How can I help you today?"}]}, { "role": "user", "parts": [ - {"text": "What does this logo represent?"}, + {"text": "hello"}, { "inline_data": { "mime_type": "image/png", @@ -106,20 +124,85 @@ def test_gemini_type_adapter_input_chat(adapter, image): }, ], }, + { + "role": "model", + "parts": [ + { + "function_call": { + "id": "abc", + "name": "tool_name", + "args": {"foo": "bar"}, + }, + } + ] + }, + { + "role": "user", + "parts": [ + { + "function_response": { + "id": "abc", + "name": "tool_name", + "response": "response", + }, + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "prompt", + } + ] + }, ] } -def test_gemini_type_adapter_input_invalid(adapter): +def test_gemini_type_adapter_input_invalid(adapter, image): @dataclass class Audio: file: str - prompt = Audio( - "file", - ) - with pytest.raises(TypeError, match="The input type"): - _ = adapter.format_input(prompt) + # Invalid input type + with pytest.raises(TypeError, match="is not available"): + _ = adapter.format_input(image) + + # Invalid type within list input + with pytest.raises( + ValueError, + match="All assets provided must be of type Image", + ): + _ = adapter.format_input(["prompt", Audio("file")]) + + # Chat message with system role + with pytest.raises(ValueError, match="System messages are not supported in Chat inputs for Gemini"): + _ = adapter.format_input(Chat(messages=[{"role": "system", "content": "prompt"}])) + + # Chat message with invalid role + with pytest.raises(ValueError, match="Invalid message role"): + _ = adapter.format_input(Chat(messages=[{"content": "prompt"}])) + + # Chat message with invalid content type + with pytest.raises(ValueError, match="Invalid content type"): + _ = adapter.format_input(Chat(messages=[{"role": "user", "content": {"foo": "bar"}}])) + + # Chat message with user role and no content + with pytest.raises(ValueError, match="Content is required for user messages"): + _ = adapter.format_input(Chat(messages=[{"role": "user"}])) + + # Chat message with assistant role and neither content nor tool calls + with pytest.raises(ValueError, match="Either content or tool calls is required for assistant messages"): + _ = adapter.format_input(Chat(messages=[{"role": "assistant"}])) + + # Chat message with tool role and no content + with pytest.raises(ValueError, match="Content and tool name are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "tool_name": "tool_name"}])) + + # Chat message with tool role and no tool name + with pytest.raises(ValueError, match="Content and tool name are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "content": "response"}])) def test_gemini_type_adapter_output_invalid(adapter): @@ -243,3 +326,53 @@ def test_gemini_type_adapter_output_literal(adapter): assert len(result["response_schema"].__members__) == 2 assert result["response_schema"].bar.value == "bar" assert result["response_schema"].fuzz.value == "fuzz" + + +def test_gemini_type_adapter_format_tools(adapter): + tools = [ + ToolDef( + name="tool_name", + description="tool_description", + parameters={"foo": {"type": "string"}}, + required=["foo"], + ), + ToolDef( + name="tool_name_2", + description="tool_description_2", + parameters={ + "foo": {"type": "string"}, + "bar": {"type": "integer"} + }, + required=["bar"], + ), + ] + result = adapter.format_tools(tools) + assert result == [ + { + "function_declarations": [{ + "name": "tool_name", + "description": "tool_description", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string"} + }, + "required": ["foo"], + }, + }], + }, + { + "function_declarations": [{ + "name": "tool_name_2", + "description": "tool_description_2", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string"}, + "bar": {"type": "integer"} + }, + "required": ["bar"], + }, + }], + }, + ] diff --git a/tests/models/test_llamacpp.py b/tests/models/test_llamacpp.py index 9b881b7de..79c71f0cd 100644 --- a/tests/models/test_llamacpp.py +++ b/tests/models/test_llamacpp.py @@ -12,6 +12,7 @@ LlamaCppTypeAdapter, from_llamacpp ) +from outlines.outputs import Output, StreamingOutput from outlines.types.dsl import Regex, CFG @@ -71,12 +72,12 @@ def ebnf_grammar(): def test_llamacpp_simple(model): - result = model.generate("Respond with one word. Not more.", None) - assert isinstance(result, str) + result = model("Respond with one word. Not more.", None) + assert isinstance(result, Output) def test_llamacpp_chat(model): - result = model.generate( + result = model( Chat( messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -85,14 +86,14 @@ def test_llamacpp_chat(model): ), max_tokens=10 ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_llamacpp_regex(model): result = model("Respond with one word. Not more.", Regex(r"[0-9]")) - assert isinstance(result, str) - assert int(result) - assert len(result) == 1 + assert isinstance(result, Output) + assert int(result.content) + assert len(result.content) == 1 def test_llamacpp_json(model): @@ -100,8 +101,8 @@ class Foo(BaseModel): bar: str result = model("foo? Respond with one word.", Foo, max_tokens=100) - assert isinstance(result, str) - assert "bar" in json.loads(result) + assert isinstance(result, Output) + assert "bar" in json.loads(result.content) def test_llamacpp_choice(model): @@ -110,12 +111,14 @@ class Foo(Enum): foor = "Foo" result = model("foo?", Foo) - assert result == "Foo" or result == "Bar" + assert isinstance(result, Output) + assert result.content == "Foo" or result.content == "Bar" def test_llamacpp_cfg(model, ebnf_grammar): response = model("Respond with one word. Not more.", CFG(ebnf_grammar)) - assert response in ["yes", "no"] + assert isinstance(response, Output) + assert response.content in ["yes", "no"] def test_llamacpp_cfg_outlines_core(model, lark_grammar): @@ -131,15 +134,16 @@ def test_llamacpp_cfg_outlines_core(model, lark_grammar): def test_llamacpp_text_stop(model): - result = model.generate("Write the letter a.", None, stop="a", max_tokens=100) - assert "a" not in result + result = model("Write the letter a.", None, stop="a", max_tokens=100) + assert isinstance(result, Output) + assert "a" not in result.content def test_llamacpp_stream_simple(model): generator = model.stream("Respond with one word. Not more.", None) for x in generator: - assert isinstance(x, str) + assert isinstance(x, StreamingOutput) def test_llamacpp_stream_chat(model): @@ -153,14 +157,16 @@ def test_llamacpp_stream_chat(model): max_tokens=10 ) for x in generator: - assert isinstance(x, str) + assert isinstance(x, StreamingOutput) def test_llamacpp_stream_regex(model): generator = model.stream("Respond with one word. Not more.", Regex(r"[0-9]")) x = next(generator) - assert isinstance(x, str) + assert isinstance(x, StreamingOutput) + assert int(x.content) + assert len(x.content) == 1 def test_llamacpp_stream_json(model): @@ -170,7 +176,8 @@ class Foo(BaseModel): generator = model.stream("foo?", Foo) x = next(generator) - assert x == "{" + assert isinstance(x, StreamingOutput) + assert "{" in x.content def test_llamacpp_stream_cfg(model, ebnf_grammar): @@ -178,7 +185,8 @@ def test_llamacpp_stream_cfg(model, ebnf_grammar): for chunk in model.stream( "Respond with one word. Not more.", CFG(ebnf_grammar) ): - response += chunk + assert isinstance(chunk, StreamingOutput) + response += chunk.content assert response in ["yes", "no"] @@ -187,7 +195,7 @@ def test_llamacpp_stream_cfg_outlines_core(model, lark_grammar): NotImplementedError, match="Outlines Core does not support context-free grammar." ): - for chunk in model.stream( + for _ in model.stream( "Respond with one word. Not more.", CFG(lark_grammar), backend="outlines_core" @@ -203,15 +211,16 @@ class Foo(Enum): generator = model.stream("foo?", Foo) x = next(generator) - assert x[0] in ("B", "F") + assert isinstance(x, StreamingOutput) + assert x.content[0] in ("B", "F") def test_llamacpp_stream_text_stop(model): generator = model.stream("Write the letter a.", None, stop="a", max_tokens=100) result = next(generator) - assert isinstance(result, str) - assert result != "a" + assert isinstance(result, StreamingOutput) + assert result.content != "a" def test_llamacpp_batch(model): diff --git a/tests/models/test_llamacpp_type_adapter.py b/tests/models/test_llamacpp_type_adapter.py index 403b3589c..680055e32 100644 --- a/tests/models/test_llamacpp_type_adapter.py +++ b/tests/models/test_llamacpp_type_adapter.py @@ -8,6 +8,7 @@ from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor from outlines.inputs import Chat, Image from outlines.models.llamacpp import LlamaCppTypeAdapter +from outlines.tools import ToolDef @pytest.fixture @@ -67,3 +68,15 @@ def test_llamacpp_type_adapter_format_output_type(adapter, logits_processor): assert isinstance(formatted, LogitsProcessorList) assert formatted[0].index == logits_processor.index assert formatted[0].tensor_library_name == logits_processor.tensor_library_name + + +def test_llamacpp_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="LlamaCpp does not support tools." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py index fc2ec6f39..f5953d99f 100644 --- a/tests/models/test_mlxlm.py +++ b/tests/models/test_mlxlm.py @@ -3,6 +3,8 @@ from enum import Enum from typing import Generator +from pydantic import BaseModel + import outlines from outlines.types import Regex from outlines.models.mlxlm import ( @@ -11,7 +13,7 @@ from_mlxlm ) from outlines.models.transformers import TransformerTokenizer -from pydantic import BaseModel +from outlines.outputs import Output, StreamingOutput try: import mlx_lm @@ -55,14 +57,14 @@ def test_mlxlm_tokenizer(model): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") def test_mlxlm_simple(model): - result = model.generate("Respond with one word. Not more.", None) - assert isinstance(result, str) + result = model("Respond with one word. Not more.", None) + assert isinstance(result, Output) @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") def test_mlxlm_call(model): result = model("Respond with one word. Not more.") - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") @@ -80,15 +82,15 @@ def test_mlxlm_invalid_inference_kwargs(model): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") def test_mlxlm_inference_kwargs(model): result = model("Write a short story about a cat.", max_tokens=2) - assert isinstance(result, str) - assert len(result) < 20 + assert isinstance(result, Output) + assert len(result.content) < 20 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") def test_mlxlm_regex(model): result = model("Give a number between 0 and 9.", Regex(r"[0-9]")) - assert isinstance(result, str) - assert re.match(r"[0-9]", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]", result.content) @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") @@ -97,7 +99,7 @@ class Character(BaseModel): name: str result = model("Create a character with a name.", Character) - assert "name" in result + assert "name" in result.content @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") @@ -107,7 +109,7 @@ class Foo(Enum): dog = "dog" result = model("Cat or dog?", Foo) - assert result in ["cat", "dog"] + assert result.content in ["cat", "dog"] @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") @@ -116,7 +118,7 @@ def test_mlxlm_stream_text_stop(model): "Respond with one word. Not more.", None, max_tokens=100 ) assert isinstance(generator, Generator) - assert isinstance(next(generator), str) + assert isinstance(next(generator), StreamingOutput) @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") diff --git a/tests/models/test_mlxlm_type_adapter.py b/tests/models/test_mlxlm_type_adapter.py index dd2feb16d..dcd627ff7 100644 --- a/tests/models/test_mlxlm_type_adapter.py +++ b/tests/models/test_mlxlm_type_adapter.py @@ -1,5 +1,5 @@ -import pytest import io +import pytest from outlines_core import Index, Vocabulary from PIL import Image as PILImage @@ -7,6 +7,7 @@ from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor from outlines.inputs import Chat, Image from outlines.models.mlxlm import MLXLMTypeAdapter +from outlines.tools import ToolDef try: import mlx_lm @@ -82,3 +83,16 @@ def test_mlxlm_type_adapter_format_output_type(adapter, logits_processor): assert isinstance(formatted, list) assert len(formatted) == 1 assert isinstance(formatted[0], OutlinesCoreLogitsProcessor) + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="MLXLM does not support tools." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index c6e39605b..32178f985 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -11,6 +11,7 @@ import outlines from outlines.inputs import Chat, Image, Video from outlines.models import AsyncOllama, Ollama +from outlines.outputs import Output, StreamingOutput MODEL_NAME = "tinyllama" @@ -73,16 +74,16 @@ def test_ollama_init_from_client(): def test_ollama_wrong_inference_parameters(model): with pytest.raises(TypeError, match="got an unexpected"): - model.generate( + model( "Respond with one word. Not more.", None, foo=10 ) def test_ollama_simple(model): - result = model.generate( + result = model( "Respond with one word. Not more.", None ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_ollama_direct(model_no_model_name): @@ -91,21 +92,21 @@ def test_ollama_direct(model_no_model_name): None, model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_ollama_simple_vision(image, model): # This is not using a vision model, so it's not able to describe # the image, but we're still checking the model input syntax - result = model.generate( + result = model( ["What does this logo represent?", Image(image)], model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_ollama_chat(image, model): - result = model.generate( + result = model( Chat( [ {"role": "system", "content": "You are a helpful assistant."}, @@ -117,7 +118,7 @@ def test_ollama_chat(image, model): ), model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_ollama_json(model): @@ -125,8 +126,8 @@ class Foo(BaseModel): foo: Annotated[str, Field(max_length=1)] result = model("Respond with one word. Not more.", Foo) - assert isinstance(result, str) - assert "foo" in json.loads(result) + assert isinstance(result, Output) + assert "foo" in json.loads(result.content) def test_ollama_wrong_output_type(model): @@ -135,20 +136,20 @@ class Foo(Enum): foor = "Foo" with pytest.raises(TypeError, match="is not supported"): - model.generate("foo?", Foo) + model("foo?", Foo) def test_ollama_wrong_input_type(model, image): with pytest.raises(TypeError, match="is not available"): - model.generate({"foo?": "bar?"}, None) + model({"foo?": "bar?"}, None) with pytest.raises(ValueError, match="All assets provided must be of type Image"): - model.generate(["foo?", Image(image), Video("")], None) + model(["foo?", Image(image), Video("")], None) def test_ollama_stream(model): generator = model.stream("Write a sentence about a cat.") - assert isinstance(next(generator), str) + assert isinstance(next(generator), StreamingOutput) def test_ollama_stream_json(model_no_model_name): @@ -157,8 +158,8 @@ class Foo(BaseModel): generator = model_no_model_name.stream("Create a character.", Foo, model=MODEL_NAME) generated_text = [] - for text in generator: - generated_text.append(text) + for chunk in generator: + generated_text.append(chunk.content) assert "foo" in json.loads("".join(generated_text)) @@ -188,17 +189,17 @@ def test_ollama_async_init_from_client(): @pytest.mark.asyncio async def test_ollama_async_wrong_inference_parameters(async_model): with pytest.raises(TypeError, match="got an unexpected"): - await async_model.generate( + await async_model( "Respond with one word. Not more.", None, foo=10 ) @pytest.mark.asyncio async def test_ollama_async_simple(async_model): - result = await async_model.generate( + result = await async_model( "Respond with one word. Not more.", None ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -208,23 +209,23 @@ async def test_ollama_async_direct(async_model_no_model_name): None, model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio async def test_ollama_async_simple_vision(image, async_model): # This is not using a vision model, so it's not able to describe # the image, but we're still checking the model input syntax - result = await async_model.generate( + result = await async_model( ["What does this logo represent?", Image(image)], model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio async def test_ollama_async_chat(image, async_model): - result = await async_model.generate( + result = await async_model( Chat( [ {"role": "system", "content": "You are a helpful assistant."}, @@ -236,7 +237,7 @@ async def test_ollama_async_chat(image, async_model): ), model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -245,8 +246,8 @@ class Foo(BaseModel): foo: Annotated[str, Field(max_length=1)] result = await async_model("Respond with one word. Not more.", Foo) - assert isinstance(result, str) - assert "foo" in json.loads(result) + assert isinstance(result, Output) + assert "foo" in json.loads(result.content) @pytest.mark.asyncio @@ -256,19 +257,19 @@ class Foo(Enum): foor = "Foo" with pytest.raises(TypeError, match="is not supported"): - await async_model.generate("foo?", Foo) + await async_model("foo?", Foo) @pytest.mark.asyncio async def test_ollama_async_wrong_input_type(async_model): with pytest.raises(TypeError, match="is not available"): - await async_model.generate({"foo?": "bar?"}, None) + await async_model({"foo?": "bar?"}, None) @pytest.mark.asyncio async def test_ollama_async_stream(async_model): async_generator = async_model.stream("Write a sentence about a cat.") - assert isinstance(await async_generator.__anext__(), str) + assert isinstance(await async_generator.__anext__(), StreamingOutput) @pytest.mark.asyncio @@ -279,7 +280,7 @@ class Foo(BaseModel): async_generator = async_model_no_model_name.stream("Create a character.", Foo, model=MODEL_NAME) generated_text = [] async for chunk in async_generator: - generated_text.append(chunk) + generated_text.append(chunk.content) assert "foo" in json.loads("".join(generated_text)) diff --git a/tests/models/test_ollama_type_adapter.py b/tests/models/test_ollama_type_adapter.py index 2061831a6..1545e06e4 100644 --- a/tests/models/test_ollama_type_adapter.py +++ b/tests/models/test_ollama_type_adapter.py @@ -10,6 +10,7 @@ from outlines.inputs import Chat, Image from outlines.models.ollama import OllamaTypeAdapter +from outlines.tools import ToolDef from outlines.types import cfg, json_schema, regex if sys.version_info >= (3, 12): @@ -168,3 +169,15 @@ def test_ollama_type_adapter_json_schema_str(adapter, schema): def test_ollama_type_adapter_json_schema_dict(adapter, schema): result = adapter.format_output_type(json_schema(schema)) assert result == schema + + +def test_ollama_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="Tools are not available for Ollama." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 249785d30..ff9fb02f9 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1,7 +1,7 @@ import io import json import os -from typing import Annotated, Generator, AsyncGenerator +from typing import Annotated, AsyncGenerator, Generator import pytest from PIL import Image as PILImage @@ -11,9 +11,16 @@ import outlines from outlines.inputs import Chat, Image, Video from outlines.models.openai import AsyncOpenAI, OpenAI +from outlines.outputs import ( + Output, + StreamingOutput, + StreamingToolCallOutput, + ToolCallOutput, +) +from outlines.tools import ToolDef from outlines.types import json_schema -MODEL_NAME = "gpt-4o-mini-2024-07-18" +MODEL_NAME = "gpt-4o-mini" @pytest.fixture(scope="session") @@ -45,6 +52,27 @@ def image(): return image +@pytest.fixture +def tools(): + return [ + ToolDef( + name="get_weather", + description="Get the current weather for a given city", + parameters={"city": {"type": "string"}}, + required=["city"], + ), + ToolDef( + name="get_user_info", + description="Get the current user info", + parameters={ + "first_name": {"type": "string"}, + "last_name": {"type": "string"} + }, + required=["last_name"], + ), + ] + + @pytest.fixture(scope="session") def model(api_key): return OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME) @@ -83,118 +111,152 @@ def test_openai_init_from_client(api_key): def test_openai_wrong_inference_parameters(model): with pytest.raises(TypeError, match="got an unexpected"): - model.generate("prompt", foo=10) - - -def test_openai_wrong_input_type(model, image): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(TypeError, match="is not available"): - model.generate(Foo("prompt")) - - with pytest.raises(ValueError, match="All assets provided must be of type Image"): - model.generate(["foo?", Image(image), Video("")]) - - -def test_openai_wrong_output_type(model): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(TypeError, match="is not available"): - model.generate("prompt", Foo(1)) + model("prompt", foo=10) @pytest.mark.api_call -def test_openai_simple_call(model): - result = model.generate("Respond with one word. Not more.") - assert isinstance(result, str) +def test_openai_call(model): + result = model("Respond with one word. Not more.") + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_openai_simple_call_multiple_samples(model): - result = model.generate("Respond with one word. Not more.", n=2) +def test_openai_call_multiple_samples(model): + result = model("Respond with one word. Not more.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + for output in result: + assert isinstance(output, Output) + assert isinstance(output.content, str) @pytest.mark.api_call -def test_openai_direct_call(model_no_model_name): - result = model_no_model_name( - "Respond with one word. Not more.", - model=MODEL_NAME, - ) - assert isinstance(result, str) - - -@pytest.mark.api_call -def test_openai_simple_vision(image, model): - result = model.generate(["What does this logo represent?", Image(image)]) - assert isinstance(result, str) +def test_openai_vision(image, model): + result = model(["What does this logo represent?", Image(image)]) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call def test_openai_chat(image, model): - result = model.generate(Chat(messages=[ + result = model(Chat(messages=[ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": ["What does this logo represent?", Image(image)] }, ]), max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call -def test_openai_simple_pydantic(model): +def test_openai_pydantic(model): class Foo(BaseModel): bar: int - result = model.generate("foo?", Foo) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.api_call -def test_openai_simple_pydantic_refusal(model): +def test_openai_pydantic_refusal(model): class Foo(BaseModel): bar: Annotated[str, Field(int, pattern=r"^\d+$")] with pytest.raises(TypeError, match="OpenAI does not support your schema"): - _ = model.generate("foo?", Foo) + _ = model("foo?", Foo) @pytest.mark.api_call -def test_openai_simple_vision_pydantic(image, model): +def test_openai_vision_pydantic(image, model): class Logo(BaseModel): name: int - result = model.generate(["What does this logo represent?", Image(image)], Logo) - assert isinstance(result, str) - assert "name" in json.loads(result) + result = model(["What does this logo represent?", Image(image)], Logo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "name" in json.loads(result.content) @pytest.mark.api_call -def test_openai_simple_json_schema(model): +def test_openai_json_schema(model): class Foo(BaseModel): bar: int schema = json.dumps(Foo.model_json_schema()) - result = model.generate("foo?", json_schema(schema)) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = model("foo?", json_schema(schema)) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) + + +@pytest.mark.api_call +def test_openai_tools(model, tools): + result = model( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=1024, + ) + assert isinstance(result, Output) + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert isinstance(tool_call, ToolCallOutput) + assert tool_call.name == "get_weather" + assert tool_call.args == {"city": "Tokyo"} + assert tool_call.id is not None + + +@pytest.mark.api_call +def test_openai_tools_chat(model, tools): + chat = Chat(messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the weather in Tokyo?"}, + ]) + generator = outlines.Generator(model, tools=tools) + result = generator(chat) + chat.add_output(result) + chat.add_tool_message( + tool_call_id=result.tool_calls[0].id, + content="The weather in Tokyo is sunny.", + ) + chat.add_user_message("Is it a good weather to go out?") + result = generator(chat) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.api_call def test_openai_streaming(model): result = model.stream("Respond with one word. Not more.") assert isinstance(result, Generator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) + assert isinstance(next(result).content, str) + + +@pytest.mark.api_call +def test_openai_streaming_tools(model, tools): + result = model.stream( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=1024, + ) + assert isinstance(result, Generator) + for chunk in result: + assert isinstance(chunk, StreamingOutput) + assert chunk.content is None + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + tool_call = chunk.tool_calls[0] + assert isinstance(tool_call, StreamingToolCallOutput) + assert tool_call.name == "get_weather" + assert isinstance(tool_call.args, str) + assert tool_call.id is not None def test_openai_batch(model): @@ -223,47 +285,26 @@ def test_openai_async_init_from_client(api_key): @pytest.mark.asyncio async def test_openai_async_wrong_inference_parameters(async_model): with pytest.raises(TypeError, match="got an unexpected"): - await async_model.generate("prompt", foo=10) - - -@pytest.mark.asyncio -async def test_openai_async_wrong_input_type(async_model, image): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(TypeError, match="is not available"): - await async_model.generate(Foo("prompt")) - - with pytest.raises(ValueError, match="All assets provided must be of type Image"): - await async_model.generate(["foo?", Image(image), Video("")]) - - -@pytest.mark.asyncio -async def test_openai_async_wrong_output_type(async_model): - class Foo: - def __init__(self, foo): - self.foo = foo - - with pytest.raises(TypeError, match="is not available"): - await async_model.generate("prompt", Foo(1)) + await async_model("prompt", foo=10) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_call(async_model): - result = await async_model.generate("Respond with one word. Not more.") - assert isinstance(result, str) +async def test_openai_async_call(async_model): + result = await async_model("Respond with one word. Not more.") + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_call_multiple_samples(async_model): - result = await async_model.generate("Respond with one word. Not more.", n=2) +async def test_openai_async_call_multiple_samples(async_model): + result = await async_model("Respond with one word. Not more.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + for output in result: + assert isinstance(output, Output) + assert isinstance(output.content, str) @pytest.mark.asyncio @@ -273,72 +314,96 @@ async def test_openai_async_direct_call(async_model_no_model_name): "Respond with one word. Not more.", model=MODEL_NAME, ) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_vision(image, async_model): - result = await async_model.generate(["What does this logo represent?", Image(image)]) - assert isinstance(result, str) +async def test_openai_async_vision(image, async_model): + result = await async_model(["What does this logo represent?", Image(image)]) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.asyncio @pytest.mark.api_call async def test_openai_async_chat(image, async_model): - result = await async_model.generate(Chat(messages=[ + result = await async_model(Chat(messages=[ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": ["What does this logo represent?", Image(image)] }, ]), max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) + assert isinstance(result.content, str) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_pydantic(async_model): +async def test_openai_async_pydantic(async_model): class Foo(BaseModel): bar: int - result = await async_model.generate("foo?", Foo) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = await async_model("foo?", Foo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_pydantic_refusal(async_model): +async def test_openai_async_pydantic_refusal(async_model): class Foo(BaseModel): bar: Annotated[str, Field(int, pattern=r"^\d+$")] with pytest.raises(TypeError, match="OpenAI does not support your schema"): - _ = await async_model.generate("foo?", Foo) + _ = await async_model("foo?", Foo) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_vision_pydantic(image, async_model): +async def test_openai_async_vision_pydantic(image, async_model): class Logo(BaseModel): name: int - result = await async_model.generate(["What does this logo represent?", Image(image)], Logo) - assert isinstance(result, str) - assert "name" in json.loads(result) + result = await async_model(["What does this logo represent?", Image(image)], Logo) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "name" in json.loads(result.content) @pytest.mark.asyncio @pytest.mark.api_call -async def test_openai_async_simple_json_schema(async_model): +async def test_openai_async_json_schema(async_model): class Foo(BaseModel): bar: int schema = json.dumps(Foo.model_json_schema()) - result = await async_model.generate("foo?", json_schema(schema)) - assert isinstance(result, str) - assert "bar" in json.loads(result) + result = await async_model("foo?", json_schema(schema)) + assert isinstance(result, Output) + assert isinstance(result.content, str) + assert "bar" in json.loads(result.content) + + +@pytest.mark.asyncio +@pytest.mark.api_call +async def test_openai_async_tools(async_model, tools): + result = await async_model( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=1024, + ) + assert isinstance(result, Output) + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert isinstance(tool_call, ToolCallOutput) + assert tool_call.name == "get_weather" + assert tool_call.args == {"city": "Tokyo"} + assert tool_call.id is not None @pytest.mark.asyncio @@ -347,10 +412,32 @@ async def test_openai_async_streaming(async_model): result = async_model.stream("Respond with a single word.") assert isinstance(result, AsyncGenerator) async for chunk in result: - assert isinstance(chunk, str) + assert isinstance(chunk, StreamingOutput) + assert isinstance(chunk.content, str) break # Just check the first chunk +@pytest.mark.asyncio +@pytest.mark.api_call +async def test_openai_async_streaming_tools(async_model, tools): + result = async_model.stream( + "What is the weather in Tokyo?", + tools=tools, + max_tokens=1024, + ) + assert isinstance(result, AsyncGenerator) + async for chunk in result: + assert isinstance(chunk, StreamingOutput) + assert chunk.content is None + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + tool_call = chunk.tool_calls[0] + assert isinstance(tool_call, StreamingToolCallOutput) + assert tool_call.name == "get_weather" + assert isinstance(tool_call.args, str) + assert tool_call.id is not None + + @pytest.mark.asyncio async def test_openai_async_batch(async_model): with pytest.raises(NotImplementedError, match="does not support"): diff --git a/tests/models/test_openai_type_adapter.py b/tests/models/test_openai_type_adapter.py index 5f523b8ea..cfcca4dc5 100644 --- a/tests/models/test_openai_type_adapter.py +++ b/tests/models/test_openai_type_adapter.py @@ -12,6 +12,7 @@ from outlines import cfg, json_schema, regex from outlines.inputs import Chat, Image from outlines.models.openai import OpenAITypeAdapter +from outlines.tools import ToolDef if sys.version_info >= (3, 12): from typing import TypedDict @@ -87,7 +88,11 @@ def test_openai_type_adapter_input_chat(adapter, image): "hello", image_input, ]}, - {"role": "assistant", "content": "response"}, + {"role": "assistant", "tool_calls": [ + {"tool_name": "tool_name", "tool_call_id": "abc", "args": {"foo": "bar"}} + ]}, + {"role": "tool", "content": "response", "tool_call_id": "abc"}, + {"role": "user", "content": "prompt"}, ]) result = adapter.format_input(model_input) assert result == [ @@ -104,31 +109,65 @@ def test_openai_type_adapter_input_chat(adapter, image): }, ] }, - {"role": "assistant", "content": "response"}, + {"role": "assistant", "tool_calls": [ + { + "type": "function", + "function": { + "name": "tool_name", + "arguments": "{'foo': 'bar'}" + }, + "id": "abc" + } + ]}, + {"role": "tool", "content": "response", "tool_call_id": "abc"}, + {"role": "user", "content": "prompt"}, ] -def test_openai_type_adapter_input_invalid(adapter): +def test_openai_type_adapter_input_invalid(adapter, image): + @dataclass class Audio: file: str + # Invalid input type with pytest.raises(TypeError, match="is not available"): - _ = adapter.format_input(Audio("file")) + _ = adapter.format_input(image) + # Invalid type within list input with pytest.raises( ValueError, match="All assets provided must be of type Image", ): _ = adapter.format_input(["prompt", Audio("file")]) - with pytest.raises( - ValueError, - match="The content must be a string or a list", - ): - _ = adapter.format_input( - Chat(messages=[{"role": "user", "content": {"foo": "bar"}}]) - ) + # Chat message with invalid role + with pytest.raises(ValueError, match="Invalid message role"): + _ = adapter.format_input(Chat(messages=[{"content": "prompt"}])) + + # Chat message with invalid content type + with pytest.raises(ValueError, match="Invalid content type"): + _ = adapter.format_input(Chat(messages=[{"role": "user", "content": {"foo": "bar"}}])) + + # Chat message with user role and no content + with pytest.raises(ValueError, match="Content is required for user messages"): + _ = adapter.format_input(Chat(messages=[{"role": "user"}])) + + # Chat message with system role and no content + with pytest.raises(ValueError, match="Content is required for system messages"): + _ = adapter.format_input(Chat(messages=[{"role": "system"}])) + + # Chat message with assistant role and neither content nor tool calls + with pytest.raises(ValueError, match="Either content or tool calls is required for assistant messages"): + _ = adapter.format_input(Chat(messages=[{"role": "assistant"}])) + + # Chat message with tool role and no content + with pytest.raises(ValueError, match="Content and tool call id are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "tool_call_id": "abc"}])) + + # Chat message with tool role and no tool call id + with pytest.raises(ValueError, match="Content and tool call id are required for tool messages"): + _ = adapter.format_input(Chat(messages=[{"role": "tool", "content": "response"}])) def test_openai_type_adapter_output_invalid(adapter): @@ -230,3 +269,55 @@ def test_openai_type_adapter_json_schema_dict(adapter, schema): assert isinstance(result, dict) assert result["response_format"]["json_schema"]["strict"] is True assert result["response_format"]["json_schema"]["schema"] == schema + + +def test_openai_type_adapter_format_tools(adapter): + tools = [ + ToolDef( + name="tool_name", + description="tool_description", + parameters={"foo": {"type": "string"}}, + required=["foo"], + ), + ToolDef( + name="tool_name_2", + description="tool_description_2", + parameters={ + "foo": {"type": "string"}, + "bar": {"type": "integer"} + }, + required=["bar"], + ), + ] + result = adapter.format_tools(tools) + assert result == [ + { + "type": "function", + "function": { + "name": "tool_name", + "description": "tool_description", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string"} + }, + "required": ["foo"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "tool_name_2", + "description": "tool_description_2", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string"}, + "bar": {"type": "integer"}, + }, + "required": ["bar"], + }, + }, + }, + ] diff --git a/tests/models/test_sglang.py b/tests/models/test_sglang.py index 741a013fe..cae4cafe0 100644 --- a/tests/models/test_sglang.py +++ b/tests/models/test_sglang.py @@ -15,6 +15,7 @@ from outlines.inputs import Chat, Image from outlines.models.sglang import SGLang, AsyncSGLang, from_sglang +from outlines.outputs import Output, StreamingOutput from outlines.types.dsl import CFG, Regex, JsonSchema from tests.test_utils.mock_openai_client import MockOpenAIClient, MockAsyncOpenAIClient @@ -231,7 +232,7 @@ def test_sglang_init(): def test_sglang_sync_simple_call(sync_model): result = sync_model("Respond with a single word.",) - assert isinstance(result, str) + assert isinstance(result, Output) def test_sglang_sync_streaming(sync_model_no_model_name): @@ -240,7 +241,7 @@ def test_sglang_sync_streaming(sync_model_no_model_name): model=sglang_model_name, ) assert isinstance(result, Generator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) def test_sglang_sync_batch(sync_model): @@ -252,7 +253,7 @@ def test_sglang_sync_batch(sync_model): def test_sglang_sync_vision(sync_model): result = sync_model(["hello", image_input], max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_sglang_sync_vision_chat(sync_model): @@ -267,15 +268,15 @@ def test_sglang_sync_vision_chat(sync_model): ]), max_tokens=10, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_sglang_sync_multiple_samples(sync_model): result = sync_model("Respond with a single word.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + assert isinstance(result[0], Output) + assert isinstance(result[1], Output) def test_sglang_sync_json(sync_model): @@ -284,14 +285,14 @@ def test_sglang_sync_json(sync_model): + ' {"bar": {"type": "string"}}}' ) result = sync_model("foo?", JsonSchema(json_string), max_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content def test_sglang_sync_regex(sync_model): result = sync_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) def test_sglang_sync_cfg(sync_model): @@ -300,14 +301,14 @@ def test_sglang_sync_cfg(sync_model): match="SGLang grammar-based structured outputs expects an EBNF" ): result = sync_model("foo?", CFG(EBNF_YES_NO_GRAMMAR), max_tokens=10) - assert isinstance(result, str) - assert result in ["yes", "no"] + assert isinstance(result, Output) + assert result.content in ["yes", "no"] @pytest.mark.asyncio async def test_sglang_async_simple_call(async_model): result = await async_model("Respond with a single word.",) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -318,7 +319,7 @@ async def test_sglang_async_streaming(async_model_no_model_name): ) assert isinstance(result, AsyncGenerator) async for chunk in result: - assert isinstance(chunk, str) + assert isinstance(chunk, StreamingOutput) break # Just check the first chunk @@ -333,7 +334,7 @@ async def test_sglang_async_batch(async_model): @pytest.mark.asyncio async def test_sglang_async_vision(async_model): result = await async_model(["hello", image_input], max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -349,7 +350,7 @@ async def test_sglang_async_vision_chat(async_model): ]), max_tokens=10, ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -357,8 +358,8 @@ async def test_sglang_async_multiple_samples(async_model): result = await async_model("Respond with a single word.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + assert isinstance(result[0], Output) + assert isinstance(result[1], Output) @pytest.mark.asyncio @@ -368,19 +369,19 @@ async def test_sglang_async_json(async_model): + ' {"bar": {"type": "string"}}}' ) result = await async_model("foo?", JsonSchema(json_string), max_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content @pytest.mark.asyncio async def test_sglang_async_regex(async_model): result = await async_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) @pytest.mark.asyncio async def test_sglang_async_cfg(async_model): result = await async_model("foo?", CFG(EBNF_YES_NO_GRAMMAR), max_tokens=10) - assert isinstance(result, str) - assert result in ["yes", "no"] + assert isinstance(result, Output) + assert result.content in ["yes", "no"] diff --git a/tests/models/test_sglang_type_adapter.py b/tests/models/test_sglang_type_adapter.py index f07ccac49..6657dc14c 100644 --- a/tests/models/test_sglang_type_adapter.py +++ b/tests/models/test_sglang_type_adapter.py @@ -7,6 +7,7 @@ from outlines.inputs import Chat, Image from outlines.models.sglang import SGLangTypeAdapter +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema @@ -166,3 +167,15 @@ def test_sglang_type_adapter_output_type( assert type_adapter.format_output_type(int) == { "extra_body": {"regex": "([+-]?(0|[1-9][0-9]*))"} } + + +def test_sglang_type_adapter_tools(type_adapter): + with pytest.raises( + NotImplementedError, + match="Tools are not available for SGLang." + ): + type_adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + type_adapter.format_tools(None) diff --git a/tests/models/test_tgi.py b/tests/models/test_tgi.py index 538f5616f..490521cf1 100644 --- a/tests/models/test_tgi.py +++ b/tests/models/test_tgi.py @@ -7,6 +7,7 @@ from huggingface_hub import InferenceClient, AsyncInferenceClient from outlines.models.tgi import TGI, AsyncTGI, from_tgi +from outlines.outputs import Output, StreamingOutput from outlines.types.dsl import CFG, Regex, JsonSchema from tests.test_utils.mock_tgi_client import MockTGIInferenceClient, MockAsyncTGIInferenceClient @@ -42,7 +43,7 @@ 'max_new_tokens': 10, 'stream': True }, - ["foo", "bar"] + [Output(content="foo"), Output(content="bar")] ), ( { @@ -108,7 +109,7 @@ def test_tgi_init(): def test_tgi_sync_simple_call(sync_model): result = sync_model("Respond with a single word.", max_new_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_tgi_sync_streaming(sync_model): @@ -117,7 +118,7 @@ def test_tgi_sync_streaming(sync_model): max_new_tokens=10, ) assert isinstance(result, Generator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) def test_tgi_sync_batch(sync_model): @@ -130,14 +131,14 @@ def test_tgi_sync_batch(sync_model): def test_tgi_sync_json(sync_model): json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}, "required": ["bar"]}' result = sync_model("foo?", JsonSchema(json_string), max_new_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content def test_tgi_sync_regex(sync_model): result = sync_model("foo?", Regex(r"[0-9]{3}"), max_new_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) def test_tgi_sync_cfg(sync_model): @@ -151,7 +152,7 @@ def test_tgi_sync_cfg(sync_model): @pytest.mark.asyncio async def test_tgi_async_simple_call(async_model): result = await async_model("Respond with a single word.", max_new_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -159,7 +160,7 @@ async def test_tgi_async_streaming(async_model): result = async_model.stream("Respond with a single word.", max_new_tokens=10) assert isinstance(result, AsyncGenerator) async for chunk in result: - assert isinstance(chunk, str) + assert isinstance(chunk, StreamingOutput) break # Just check the first chunk @@ -175,15 +176,15 @@ async def test_tgi_async_batch(async_model): async def test_tgi_async_json(async_model): json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}, "required": ["bar"]}' result = await async_model("foo?", JsonSchema(json_string), max_new_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content @pytest.mark.asyncio async def test_tgi_async_regex(async_model): result = await async_model("foo?", Regex(r"[0-9]{3}"), max_new_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) @pytest.mark.asyncio diff --git a/tests/models/test_tgi_model_adapter.py b/tests/models/test_tgi_type_adapter.py similarity index 85% rename from tests/models/test_tgi_model_adapter.py rename to tests/models/test_tgi_type_adapter.py index 42c91e3d6..820974bbd 100644 --- a/tests/models/test_tgi_model_adapter.py +++ b/tests/models/test_tgi_type_adapter.py @@ -2,6 +2,7 @@ import pytest from outlines.models.tgi import TGITypeAdapter +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema @@ -86,3 +87,15 @@ def test_tgi_type_adapter_output_type_invalid( match="TGI does not support CFG-based structured outputs.", ): type_adapter.format_output_type(cfg_instance) + + +def test_tgi_type_adapter_tools(type_adapter): + with pytest.raises( + NotImplementedError, + match="Tools are not available for TGI.", + ): + type_adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + type_adapter.format_tools(None) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 10f572794..3f8f27bab 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -12,6 +12,7 @@ TransformerTokenizer, TransformersTypeAdapter, ) +from outlines.outputs import Output, StreamingOutput from outlines.types import Regex @@ -80,16 +81,16 @@ def model_bart(): def test_transformers_simple(model): - result = model.generate("Respond with one word. Not more.", None) - assert isinstance(result, str) + result = model("Respond with one word. Not more.", None) + assert isinstance(result, Output) def test_transformers_call(model, model_bart): result = model("Respond with one word. Not more.") - assert isinstance(result, str) + assert isinstance(result, Output) result = model_bart("Respond with one word. Not more.") - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_chat(model): @@ -99,12 +100,12 @@ def test_transformers_chat(model): {"role": "user", "content": "What is the capital of France?"}, ]) ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_inference_kwargs(model): result = model("Respond with one word. Not more.", max_new_tokens=100) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_invalid_inference_kwargs(model): @@ -114,8 +115,8 @@ def test_transformers_invalid_inference_kwargs(model): def test_transformers_regex(model): result = model("Give a number between 0 and 9.", Regex(r"[0-9]")) - assert isinstance(result, str) - assert re.match(r"[0-9]", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]", result.content) def test_transformers_json(model): @@ -123,7 +124,7 @@ class Character(BaseModel): name: str result = model("Create a character with a name.", Character) - assert "name" in result + assert "name" in result.content def test_transformers_choice(model): @@ -132,12 +133,12 @@ class Foo(Enum): dog = "dog" result = model("Cat or dog?", Foo) - assert result in ["cat", "dog"] + assert result.content in ["cat", "dog"] def test_transformers_multiple_samples(model): result = model("Respond with one word. Not more.") - assert isinstance(result, str) + assert isinstance(result, Output) result = model( "Respond with one word. Not more.", num_return_sequences=2, do_sample=True ) @@ -187,8 +188,8 @@ class Foo(Enum): result = model("Cat or dog?", Foo, num_return_sequences=2, do_sample=True) assert isinstance(result, list) assert len(result) == 2 - assert result[0] in ["cat", "dog"] - assert result[1] in ["cat", "dog"] + assert result[0].content in ["cat", "dog"] + assert result[1].content in ["cat", "dog"] def test_transformers_batch_constrained(model): @@ -202,8 +203,8 @@ class Foo(Enum): ) assert isinstance(result, list) assert len(result) == 2 - assert result[0] in ["cat", "dog"] - assert result[1] in ["cat", "dog"] + assert result[0].content in ["cat", "dog"] + assert result[1].content in ["cat", "dog"] result = model.batch( ["Cat or dog?", "Cat or dog?"], @@ -216,8 +217,8 @@ class Foo(Enum): for item in result: assert isinstance(item, list) assert len(item) == 2 - assert item[0] in ["cat", "dog"] - assert item[1] in ["cat", "dog"] + assert item[0].content in ["cat", "dog"] + assert item[1].content in ["cat", "dog"] def test_transformers_streaming(model): diff --git a/tests/models/test_transformers_multimodal.py b/tests/models/test_transformers_multimodal.py index 41ff058e9..e4c40dc28 100644 --- a/tests/models/test_transformers_multimodal.py +++ b/tests/models/test_transformers_multimodal.py @@ -19,6 +19,7 @@ TransformerTokenizer, TransformersMultiModalTypeAdapter, ) +from outlines.outputs import Output, StreamingOutput from outlines.types import Regex TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration" @@ -61,12 +62,12 @@ def test_transformers_multimodal_instantiate_simple(): def test_transformers_multimodal_simple(model, image): - result = model.generate( + result = model( ["Describe this image in one sentence:", Image(image)], None, max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_multimodal_call(model, image): @@ -74,7 +75,7 @@ def test_transformers_multimodal_call(model, image): ["Describe this image in one sentence:", Image(image)], max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_multimodal_wrong_number_image(model, image): @@ -90,7 +91,7 @@ def test_transformers_multimodal_wrong_number_image(model, image): def test_transformers_multimodal_wrong_input_type(model): with pytest.raises(TypeError): - model.generate("invalid input", None) + model("invalid input", None) def test_transformers_multimodal_chat(model, image): @@ -107,7 +108,7 @@ def test_transformers_multimodal_chat(model, image): ]), max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_inference_kwargs(model, image): @@ -115,7 +116,7 @@ def test_transformers_inference_kwargs(model, image): ["Describe this image in one sentence:", Image(image)], max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_invalid_inference_kwargs(model, image): @@ -138,7 +139,7 @@ def test_transformers_several_image(model, image): ], max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_transformers_multimodal_json(model, image): @@ -150,7 +151,8 @@ class Foo(BaseModel): Foo, max_new_tokens=10, ) - assert "name" in result + assert isinstance(result, Output) + assert "name" in result.content def test_transformers_multimodal_regex(model, image): @@ -159,8 +161,8 @@ def test_transformers_multimodal_regex(model, image): Regex(r"[0-9]") ) - assert isinstance(result, str) - assert re.match(r"[0-9]", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]", result.content) def test_transformers_multimodal_choice(model, image): @@ -173,8 +175,8 @@ class Foo(Enum): Foo, ) - assert isinstance(result, str) - assert result in ["white", "blue"] + assert isinstance(result, Output) + assert result.content in ["white", "blue"] def test_transformers_multimodal_multiple_samples(model, image): @@ -245,7 +247,7 @@ def test_transformers_multimodal_batch(model, image): def test_transformers_multimodal_deprecated_input_type(model, image): with pytest.warns(DeprecationWarning): - result = model.generate( + result = model( { "text": "Describe this image in one sentence:", "image": image, @@ -253,4 +255,4 @@ def test_transformers_multimodal_deprecated_input_type(model, image): None, max_new_tokens=2, ) - assert isinstance(result, str) + assert isinstance(result, Output) diff --git a/tests/models/test_transformers_multimodal_type_adapter.py b/tests/models/test_transformers_multimodal_type_adapter.py index 10df1b0e9..ea5b854ef 100644 --- a/tests/models/test_transformers_multimodal_type_adapter.py +++ b/tests/models/test_transformers_multimodal_type_adapter.py @@ -4,9 +4,10 @@ from outlines_core import Index, Vocabulary from transformers import AutoTokenizer, LogitsProcessorList +from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor from outlines.inputs import Chat, Image, Video from outlines.models.transformers import TransformersMultiModalTypeAdapter -from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor +from outlines.tools import ToolDef MODEL_NAME = "erwanf/gpt2-mini" @@ -85,3 +86,15 @@ def test_transformers_multimodal_type_adapter_format_output_type( formatted = adapter.format_output_type(None) assert formatted is None + + +def test_transformers_multimodal_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="TransformersMultiModal does not support tools." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_transformers_type_adapter.py b/tests/models/test_transformers_type_adapter.py index da7bbfcd2..2f108cf88 100644 --- a/tests/models/test_transformers_type_adapter.py +++ b/tests/models/test_transformers_type_adapter.py @@ -9,7 +9,7 @@ from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor from outlines.inputs import Chat, Image from outlines.models.transformers import TransformersTypeAdapter - +from outlines.tools import ToolDef MODEL_NAME = "erwanf/gpt2-mini" @@ -71,3 +71,15 @@ def test_transformers_type_adapter_format_output_type( formatted = adapter.format_output_type(None) assert formatted is None + + +def test_transformers_type_adapter_tools(adapter): + with pytest.raises( + NotImplementedError, + match="Transformers does not support tools." + ): + adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + adapter.format_tools(None) diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py index cd967a3cb..ff3c5f816 100644 --- a/tests/models/test_vllm.py +++ b/tests/models/test_vllm.py @@ -11,6 +11,7 @@ from outlines.inputs import Chat, Image from outlines.models.vllm import VLLM, AsyncVLLM, from_vllm +from outlines.outputs import Output, StreamingOutput from outlines.types.dsl import CFG, Regex, JsonSchema from tests.test_utils.mock_openai_client import MockOpenAIClient, MockAsyncOpenAIClient @@ -225,7 +226,7 @@ def test_vllm_init(): def test_vllm_sync_simple_call(sync_model): result = sync_model("Respond with a single word.",) - assert isinstance(result, str) + assert isinstance(result, Output) def test_vllm_sync_streaming(sync_model_no_model_name): @@ -234,7 +235,7 @@ def test_vllm_sync_streaming(sync_model_no_model_name): model=vllm_model_name, ) assert isinstance(result, Generator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) def test_vllm_sync_batch(sync_model): @@ -246,7 +247,7 @@ def test_vllm_sync_batch(sync_model): def test_vllm_sync_vision(sync_model): result = sync_model(["hello", image_input], max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_vllm_sync_vision_chat(sync_model): @@ -261,40 +262,40 @@ def test_vllm_sync_vision_chat(sync_model): ]), max_tokens=10, ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_vllm_sync_multiple_samples(sync_model): result = sync_model("Respond with a single word.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + assert isinstance(result[0], Output) + assert isinstance(result[1], Output) def test_vllm_sync_json(sync_model): json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}}' result = sync_model("foo?", JsonSchema(json_string), max_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content def test_vllm_sync_regex(sync_model): result = sync_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) def test_vllm_sync_cfg(sync_model): result = sync_model("foo?", CFG(YES_NO_GRAMMAR), max_tokens=10) - assert isinstance(result, str) - assert result in ["yes", "no"] + assert isinstance(result, Output) + assert result.content in ["yes", "no"] @pytest.mark.asyncio async def test_vllm_async_simple_call(async_model): result = await async_model("Respond with a single word.",) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -305,7 +306,7 @@ async def test_vllm_async_streaming(async_model_no_model_name): ) assert isinstance(result, AsyncGenerator) async for chunk in result: - assert isinstance(chunk, str) + assert isinstance(chunk, StreamingOutput) break # Just check the first chunk @@ -320,7 +321,7 @@ async def test_vllm_async_batch(async_model): @pytest.mark.asyncio async def test_vllm_async_vision(async_model): result = await async_model(["hello", image_input], max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -336,7 +337,7 @@ async def test_vllm_async_vision_chat(async_model): ]), max_tokens=10, ) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -344,27 +345,27 @@ async def test_vllm_async_multiple_samples(async_model): result = await async_model("Respond with a single word.", n=2) assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], str) + assert isinstance(result[0], Output) + assert isinstance(result[1], Output) @pytest.mark.asyncio async def test_vllm_async_json(async_model): json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}}' result = await async_model("foo?", JsonSchema(json_string), max_tokens=10) - assert isinstance(result, str) - assert "bar" in result + assert isinstance(result, Output) + assert "bar" in result.content @pytest.mark.asyncio async def test_vllm_async_regex(async_model): result = await async_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10) - assert isinstance(result, str) - assert re.match(r"[0-9]{3}", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]{3}", result.content) @pytest.mark.asyncio async def test_vllm_async_cfg(async_model): result = await async_model("foo?", CFG(YES_NO_GRAMMAR), max_tokens=10) - assert isinstance(result, str) - assert result in ["yes", "no"] + assert isinstance(result, Output) + assert result.content in ["yes", "no"] diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index e3c43bd50..c139b0238 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -19,6 +19,7 @@ VLLMOfflineTypeAdapter, from_vllm_offline ) +from outlines.outputs import Output, StreamingOutput from outlines.types import Regex @@ -58,13 +59,13 @@ def model(tmp_path_factory): def test_vllm_simple(model): - result = model.generate("Respond with one word. Not more.", None) - assert isinstance(result, str) + result = model("Respond with one word. Not more.", None) + assert isinstance(result, Output) def test_vllm_call(model): result = model("Respond with one word. Not more.") - assert isinstance(result, str) + assert isinstance(result, Output) def test_vllm_inference_kwargs(model): @@ -73,8 +74,8 @@ def test_vllm_inference_kwargs(model): sampling_params=SamplingParams(max_tokens=2), use_tqdm=True ) - assert isinstance(result, str) - assert len(result) <= 20 + assert isinstance(result, Output) + assert len(result.content) <= 20 def test_vllm_chat(model): @@ -86,7 +87,7 @@ def test_vllm_chat(model): ]), sampling_params=SamplingParams(max_tokens=2), ) - assert isinstance(result, str) + assert isinstance(result, Output) def test_vllm_invalid_inference_kwargs(model): @@ -96,8 +97,8 @@ def test_vllm_invalid_inference_kwargs(model): def test_vllm_regex(model): result = model("Give a number between 0 and 9.", Regex(r"[0-9]")) - assert isinstance(result, str) - assert re.match(r"[0-9]", result) + assert isinstance(result, Output) + assert re.match(r"[0-9]", result.content) def test_vllm_json(model): @@ -105,7 +106,7 @@ class Character(BaseModel): name: str result = model("Create a character with a name.", Character) - assert "name" in result + assert "name" in result.content def test_vllm_choice(model): @@ -114,7 +115,7 @@ class Foo(Enum): dog = "dog" result = model("Cat or dog?", Foo) - assert result in ["cat", "dog"] + assert result.content in ["cat", "dog"] def test_vllm_multiple_samples(model): diff --git a/tests/models/test_vllm_offline_type_adapter.py b/tests/models/test_vllm_offline_type_adapter.py index c431230e7..4606396b0 100644 --- a/tests/models/test_vllm_offline_type_adapter.py +++ b/tests/models/test_vllm_offline_type_adapter.py @@ -6,6 +6,7 @@ from outlines.inputs import Chat, Image from outlines.models.vllm_offline import VLLMOfflineTypeAdapter +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema, Regex @@ -113,3 +114,15 @@ def test_vllm_offline_type_adapter_output_type( assert type_adapter.format_output_type(regex_instance) == { "regex": "([0-9]+)" } + + +def test_vllm_offline_type_adapter_tools(type_adapter): + with pytest.raises( + NotImplementedError, + match="Tools are not available for VLLM offline." + ): + type_adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + type_adapter.format_tools(None) diff --git a/tests/models/test_vllm_type_adapter.py b/tests/models/test_vllm_type_adapter.py index 8f208830b..e03f04e8a 100644 --- a/tests/models/test_vllm_type_adapter.py +++ b/tests/models/test_vllm_type_adapter.py @@ -7,6 +7,7 @@ from outlines.inputs import Chat, Image from outlines.models.vllm import VLLMTypeAdapter +from outlines.tools import ToolDef from outlines.types import CFG, JsonSchema @@ -143,3 +144,15 @@ def test_vllm_type_adapter_output_type( assert type_adapter.format_output_type(int) == { "guided_regex": "([+-]?(0|[1-9][0-9]*))" } + + +def test_vllm_type_adapter_tools(type_adapter): + with pytest.raises( + NotImplementedError, + match="Tools are not available for VLLM." + ): + type_adapter.format_tools( + [ToolDef(name="test", description="test", parameters={})] + ) + + type_adapter.format_tools(None) diff --git a/tests/test_applications.py b/tests/test_applications.py index 5fed7bf23..7071a2bbf 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -6,6 +6,7 @@ from outlines import from_transformers from outlines.applications import Application +from outlines.outputs import Output from outlines.templates import Template @@ -51,7 +52,7 @@ def test_application_template_call(model): application = Application(template, output_type) result = application(model, {"value": "example"}, max_new_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_application_callable_call(model): @@ -62,7 +63,7 @@ def template(value): application = Application(template, output_type) result = application(model, {"value": "example"}, max_new_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_application_template_error(model): diff --git a/tests/test_generator.py b/tests/test_generator.py index e14ca5381..c85778f47 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -13,6 +13,7 @@ AsyncBlackBoxGenerator, ) from outlines.models import AsyncVLLM, VLLM +from outlines.outputs import Output, StreamingOutput from outlines.processors import ( OutlinesLogitsProcessor, ) @@ -113,7 +114,7 @@ def test_steerable_generator_init_invalid_output_type(steerable_model, sample_pr def test_steerable_generator_call(steerable_model): generator = SteerableGenerator(steerable_model, Literal["foo", "bar"]) result = generator("foo", max_new_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_steerable_generator_stream(steerable_model): @@ -121,7 +122,7 @@ def test_steerable_generator_stream(steerable_model): generator = SteerableGenerator(steerable_model, Literal["foo", "bar"]) result = generator.stream("foo", max_tokens=10) assert isinstance(result, TypingGenerator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) # BlackBoxGenerator @@ -135,14 +136,14 @@ def test_black_box_generator_init(black_box_sync_model): def test_black_box_generator_call(black_box_sync_model): generator = BlackBoxGenerator(black_box_sync_model, str) result = generator("Write a very short sentence", max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) def test_black_box_generator_stream(black_box_sync_model): generator = BlackBoxGenerator(black_box_sync_model, str) result = generator.stream("Write a very short sentence", max_tokens=10) assert isinstance(result, TypingGenerator) - assert isinstance(next(result), str) + assert isinstance(next(result), StreamingOutput) # AsyncBlackBoxGenerator @@ -158,7 +159,7 @@ def test_async_black_box_generator_init(black_box_async_model): async def test_async_black_box_generator_call(black_box_async_model): generator = AsyncBlackBoxGenerator(black_box_async_model, str) result = await generator("Write a very short sentence", max_tokens=10) - assert isinstance(result, str) + assert isinstance(result, Output) @pytest.mark.asyncio @@ -167,7 +168,7 @@ async def test_async_black_box_generator_stream(black_box_async_model): result = generator.stream("Write a very short sentence", max_tokens=10) assert isinstance(result, AsyncGenerator) async for chunk in result: - assert isinstance(chunk, str) + assert isinstance(chunk.content, str) break # Just check the first chunk diff --git a/tests/test_inputs.py b/tests/test_inputs.py index d020a48c9..b2f81641b 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -1,14 +1,13 @@ """Unit tests for the inputs module.""" import base64 -import tempfile from io import BytesIO -from typing import Dict, List, Any import pytest from PIL import Image as PILImage from outlines.inputs import Image, Video, Audio, Chat +from outlines.outputs import ToolCallOutput, Output @pytest.fixture @@ -20,6 +19,15 @@ def image_input(): return Image(image=image) +@pytest.fixture +def tool_call(): + return { + "tool_name": "foo", + "tool_call_id": "abc", + "args": {"bar": 1} + } + + def test_image_initialization(): # png image = PILImage.new("RGB", (100, 100), color="red") @@ -98,6 +106,7 @@ def test_chat_append(): assert len(chat.messages) == 1 assert chat.messages[0] == message + def test_chat_extend(): chat = Chat(messages=[]) messages = [ @@ -108,6 +117,7 @@ def test_chat_extend(): assert len(chat.messages) == 2 assert chat.messages == messages + def test_chat_pop(): # Pop from non-empty chat messages = [ @@ -126,31 +136,23 @@ def test_chat_pop(): chat.pop() -def test_chat_add_system_message(image_input): - # Add a string +def test_chat_add_system_message(): chat = Chat(messages=[]) chat.add_system_message("You are a helpful assistant.") assert len(chat.messages) == 1 assert chat.messages[0]["role"] == "system" assert chat.messages[0]["content"] == "You are a helpful assistant." - # Add a list - chat = Chat(messages=[]) - chat.add_system_message(["prompt", image_input]) - assert len(chat.messages) == 1 - assert chat.messages[0]["role"] == "system" - assert chat.messages[0]["content"] == ["prompt", image_input] - -def test_add_user_message_string(image_input): - # Add a string +def test_add_user_message(image_input): + # String content chat = Chat(messages=[]) chat.add_user_message("Hello, how are you?") assert len(chat.messages) == 1 assert chat.messages[0]["role"] == "user" assert chat.messages[0]["content"] == "Hello, how are you?" - # Add a list + # List content chat = Chat(messages=[]) chat.add_user_message(["prompt", image_input]) assert len(chat.messages) == 1 @@ -158,17 +160,56 @@ def test_add_user_message_string(image_input): assert chat.messages[0]["content"] == ["prompt", image_input] -def test_add_assistant_message_string(image_input): - # Add a string +def test_add_assistant_message(image_input, tool_call): + # String content chat = Chat(messages=[]) chat.add_assistant_message("I'm doing well, thank you!") assert len(chat.messages) == 1 assert chat.messages[0]["role"] == "assistant" assert chat.messages[0]["content"] == "I'm doing well, thank you!" - # Add a list + # List content chat = Chat(messages=[]) chat.add_assistant_message(["prompt", image_input]) assert len(chat.messages) == 1 assert chat.messages[0]["role"] == "assistant" assert chat.messages[0]["content"] == ["prompt", image_input] + + # Tool calls + chat = Chat() + chat.add_assistant_message("hello", tool_calls=[tool_call]) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "assistant" + assert chat.messages[0]["content"] == "hello" + assert chat.messages[0]["tool_calls"] == [tool_call] + + +def test_add_tool_message(): + chat = Chat() + chat.add_tool_message("response", tool_call_id="abc", tool_name="foo") + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "tool" + assert chat.messages[0]["content"] == "response" + assert chat.messages[0]["tool_call_id"] == "abc" + assert chat.messages[0]["tool_name"] == "foo" + + +def test_add_output(tool_call): + # Without tool calls + output = Output(content="response") + chat = Chat() + chat.add_output(output) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "assistant" + assert chat.messages[0]["content"] == "response" + assert chat.messages[0]["tool_calls"] is None + + # With tool calls + tool_call_output = ToolCallOutput(name="foo", args={"bar": 1}, id="abc") + output = Output(content="response", tool_calls=[tool_call_output]) + chat = Chat() + chat.add_output(output) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "assistant" + assert chat.messages[0]["content"] == "response" + assert chat.messages[0]["tool_calls"] == [tool_call] diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 000000000..2d3efadd3 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,48 @@ +import pytest +from outlines.outputs import Output, StreamingOutput, ToolCallOutput, StreamingToolCallOutput + + +def test_tool_call_output(): + tool_call = ToolCallOutput( + name="test_tool", + args={"param": "value"}, + id="call_123" + ) + assert tool_call.name == "test_tool" + assert tool_call.args == {"param": "value"} + assert tool_call.id == "call_123" + + +def test_streaming_tool_call_output(): + tool_call = StreamingToolCallOutput( + name="streaming_tool", + args="partial_args", + id="stream_456" + ) + assert tool_call.name == "streaming_tool" + assert tool_call.args == "partial_args" + assert tool_call.id == "stream_456" + + +def test_output(): + tool_calls = [ToolCallOutput(name="test", args={"arg": "value"})] + output = Output(content="Hello", tool_calls=tool_calls) + + assert output.content == "Hello" + assert output.tool_calls == tool_calls + assert str(output) == "Hello" + assert "Output(" in repr(output) + assert output + " World" == "Hello World" + assert "Hi " + output == "Hi Hello" + + +def test_streaming_output(): + tool_calls = [StreamingToolCallOutput(name="stream", args="partial")] + output = StreamingOutput(content="Streaming", tool_calls=tool_calls) + + assert output.content == "Streaming" + assert output.tool_calls == tool_calls + assert str(output) == "Streaming" + assert "StreamingOutput(" in repr(output) + assert output + " content" == "Streaming content" + assert "Live " + output == "Live Streaming" diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 000000000..0f7b11877 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,278 @@ +import pytest +from typing import Optional, Union +from pydantic import BaseModel + +from outlines.tools import ( + get_formatted_tools, + _callable_to_tool_def, + _pydantic_model_to_tool_def, + _type_to_string, +) + + +def test_get_formatted_tools_none(): + result = get_formatted_tools(None) + assert result is None + + +def test_get_formatted_tools_empty_list(): + result = get_formatted_tools([]) + assert result is None + + +def test_get_formatted_tools_tool_def(): + tool_def = { + "name": "test_tool", + "description": "A test tool", + "parameters": {"param1": {"type": "string"}}, + "required": ["param1"] + } + result = get_formatted_tools([tool_def]) + assert result == [tool_def] + + +def test_get_formatted_tools_invalid_tool_def(): + invalid_tool_def = { + "name": "test_tool", + "description": "A test tool" + } + with pytest.raises(ValueError, match="Invalid ToolDef"): + get_formatted_tools([invalid_tool_def]) + + +def test_get_formatted_tools_callable(): + def test_function(param1: str, param2: int = 5) -> str: + """A test function.""" + return f"{param1}_{param2}" + + result = get_formatted_tools([test_function]) + expected = { + "name": "test_function", + "description": "A test function.", + "parameters": { + "param1": {"type": "string"}, + "param2": {"type": "integer"} + }, + "required": ["param1"] + } + assert result == [expected] + + +def test_get_formatted_tools_pydantic(): + class TestModel(BaseModel): + """A test model.""" + field1: str + field2: int = 10 + + result = get_formatted_tools([TestModel]) + expected = { + "name": "TestModel", + "description": "A test model.", + "parameters": { + "field1": {"type": "string"}, + "field2": {"type": "integer"} + }, + "required": ["field1"] + } + assert result == [expected] + + +def test_get_formatted_tools_mixed(): + def test_func(param: str) -> str: + """Test function.""" + return param + + class TestModel(BaseModel): + field: str + + tool_def = { + "name": "dict_tool", + "description": "Dict tool", + "parameters": {"param": {"type": "string"}}, + "required": ["param"] + } + + result = get_formatted_tools([test_func, TestModel, tool_def]) + assert len(result) == 3 + assert result[0]["name"] == "test_func" + assert result[1]["name"] == "TestModel" + assert result[2]["name"] == "dict_tool" + + +def test_get_formatted_tools_unsupported(): + with pytest.raises(ValueError, match="Unsupported tool type"): + get_formatted_tools([123]) # int is not supported + + +def test_callable_to_tool_def_with_docfunction_withstring(): + def test_func(param1: str, param2: int) -> str: + """This is a test function.""" + return f"{param1}_{param2}" + + result = _callable_to_tool_def(test_func) + expected = { + "name": "test_func", + "description": "This is a test function.", + "parameters": { + "param1": {"type": "string"}, + "param2": {"type": "integer"} + }, + "required": ["param1", "param2"] + } + assert result == expected + + +def test_callable_to_tool_def_without_docstring(): + def test_func(param1: str) -> str: + return param1 + + result = _callable_to_tool_def(test_func) + expected = { + "name": "test_func", + "description": "Function test_func", + "parameters": { + "param1": {"type": "string"} + }, + "required": ["param1"] + } + assert result == expected + + +def test_callable_to_tool_def_with_defaults(): + def test_func(param1: str, param2: int = 5, param3: bool = True) -> str: + """Test function with defaults.""" + return f"{param1}_{param2}_{param3}" + + result = _callable_to_tool_def(test_func) + expected = { + "name": "test_func", + "description": "Test function with defaults.", + "parameters": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + "param3": {"type": "boolean"} + }, + "required": ["param1"] + } + assert result == expected + + +def test_callable_to_tool_def_without_annotations(): + def test_func(param1, param2=5): + return f"{param1}_{param2}" + + with pytest.raises( + ValueError, + match="All parameters must have an annotation.", + ): + _callable_to_tool_def(test_func) + + +def test_callable_to_tool_def_with_union_type(): + def test_func(param1: Union[str, int]) -> str: + """Test function with Union type.""" + return str(param1) + + result = _callable_to_tool_def(test_func) + expected = { + "name": "test_func", + "description": "Test function with Union type.", + "parameters": { + "param1": {"type": "string|integer"} + }, + "required": ["param1"] + } + assert result == expected + + +def test_pydantic_model_to_tool_def_with_description(): + class TestModel(BaseModel): + """A test model with description.""" + field1: str + field2: int = 10 + + result = _pydantic_model_to_tool_def(TestModel) + expected = { + "name": "TestModel", + "description": "A test model with description.", + "parameters": { + "field1": {"type": "string"}, + "field2": {"type": "integer"} + }, + "required": ["field1"] + } + assert result == expected + + +def test_pydantic_model_to_tool_def_without_description(): + class TestModel(BaseModel): + field1: str + + result = _pydantic_model_to_tool_def(TestModel) + expected = { + "name": "TestModel", + "description": "Model TestModel", + "parameters": { + "field1": {"type": "string"} + }, + "required": ["field1"] + } + assert result == expected + + +def test_pydantic_model_to_tool_def_with_union_type(): + class TestModel(BaseModel): + string_field: str + int_field: int + float_field: float + bool_field: bool + optional_field: Optional[str] = None + + result = _pydantic_model_to_tool_def(TestModel) + expected = { + "name": "TestModel", + "description": "Model TestModel", + "parameters": { + "string_field": {"type": "string"}, + "int_field": {"type": "integer"}, + "float_field": {"type": "number"}, + "bool_field": {"type": "boolean"}, + "optional_field": {"type": "string"} + }, + "required": ["string_field", "int_field", "float_field", "bool_field"] + } + assert result == expected + + +def test_type_to_string_basic_types(): + assert _type_to_string(str) == "string" + assert _type_to_string(int) == "integer" + assert _type_to_string(float) == "number" + assert _type_to_string(bool) == "boolean" + assert _type_to_string(list) == "array" + assert _type_to_string(dict) == "object" + + +def test_type_to_string_union_types(): + # Test Union[str, int] + union_type = Union[str, int] + result = _type_to_string(union_type) + assert result == "string|integer" + + # Test Optional[str] + optional_type = Optional[str] + result = _type_to_string(optional_type) + assert result == "string|null" + + # Test Union[str, int, None] + union_with_none = Union[str, int, None] + result = _type_to_string(union_with_none) + assert result == "string|integer|null" + + +def test_type_to_string_unsupported_type(): + class CustomType: + pass + + with pytest.raises(ValueError, match="Unsupported type"): + _type_to_string(CustomType)