Skip to content

Commit 988ea4c

Browse files
committed
Implement tools and outputs for the Dottxt model
1 parent fcfb303 commit 988ea4c

File tree

3 files changed

+46
-12
lines changed

3 files changed

+46
-12
lines changed

outlines/models/dottxt.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Integration with Dottxt's API."""
22

33
import json
4-
from typing import TYPE_CHECKING, Any, Optional
4+
from typing import TYPE_CHECKING, Any, Optional, List
55

66
from pydantic import TypeAdapter
77

88
from outlines.models.base import Model, ModelTypeAdapter
9+
from outlines.outputs import Output
10+
from outlines.tools import ToolDef
911
from outlines.types import CFG, JsonSchema, Regex
1012
from outlines.types.utils import (
1113
is_dataclass,
@@ -44,7 +46,7 @@ def format_input(self, model_input: str) -> str:
4446
"The only available type is `str`."
4547
)
4648

47-
def format_output_type(self, output_type: Optional[Any] = None) -> str:
49+
def format_output_type(self, output_type: Optional[Any]) -> str:
4850
"""Format the output type to pass to the client.
4951
5052
TODO: `int`, `float` and other Python types could be supported via
@@ -98,6 +100,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> str:
98100
"Consider using a local mode instead."
99101
)
100102

103+
def format_tools(self, tools):
104+
"""Not implemented for Dottxt."""
105+
if tools:
106+
raise NotImplementedError(
107+
"Dottxt does not support tools."
108+
)
109+
101110

102111
class Dottxt(Model):
103112
"""Thin wrapper around the `dottxt.client.Dottxt` client.
@@ -132,9 +141,10 @@ def __init__(
132141
def generate(
133142
self,
134143
model_input: str,
135-
output_type: Optional[Any] = None,
144+
output_type: Optional[Any],
145+
tools: Optional[List[ToolDef]],
136146
**inference_kwargs: Any,
137-
) -> str:
147+
) -> Output:
138148
"""Generate text using Dottxt.
139149
140150
Parameters
@@ -145,15 +155,18 @@ def generate(
145155
The desired format of the response generated by the model. The
146156
output type must be of a type that can be converted to a JSON
147157
schema.
158+
tools
159+
The tools to use for the generation.
148160
**inference_kwargs
149161
Additional keyword arguments to pass to the client.
150162
151163
Returns
152164
-------
153-
str
165+
Output
154166
The text generated by the model.
155167
156168
"""
169+
self.type_adapter.format_tools(tools)
157170
prompt = self.type_adapter.format_input(model_input)
158171
json_schema = self.type_adapter.format_output_type(output_type)
159172

@@ -174,22 +187,26 @@ def generate(
174187
json_schema,
175188
**inference_kwargs,
176189
)
177-
return completion.data
190+
191+
return Output(content=completion.data)
178192

179193
def generate_batch(
180194
self,
181195
model_input,
182-
output_type = None,
196+
output_type,
197+
tools,
183198
**inference_kwargs,
184199
):
200+
"""Not available for Dottxt."""
185201
raise NotImplementedError(
186202
"Dottxt does not support batch generation."
187203
)
188204

189205
def generate_stream(
190206
self,
191207
model_input,
192-
output_type=None,
208+
output_type,
209+
tools,
193210
**inference_kwargs,
194211
):
195212
"""Not available for Dottxt."""

tests/models/test_dottxt.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import outlines
99
from outlines import Generator
1010
from outlines.models.dottxt import Dottxt
11+
from outlines.outputs import Output
1112

1213

1314
MODEL_NAME = "dottxt/dottxt-v1-alpha"
@@ -99,7 +100,8 @@ def test_dottxt_wrong_inference_parameters(model_no_model_name):
99100
@pytest.mark.api_call
100101
def test_dottxt_direct_pydantic_call(model_no_model_name):
101102
result = model_no_model_name("Create a user", User)
102-
assert "first_name" in json.loads(result)
103+
assert isinstance(result, Output)
104+
assert "first_name" in json.loads(result.content)
103105

104106

105107
@pytest.mark.api_call
@@ -112,14 +114,16 @@ def test_dottxt_direct_jsonschema_call(
112114
model_name=model_name_and_revision[0],
113115
model_revision=model_name_and_revision[1],
114116
)
115-
assert "first_name" in json.loads(result)
117+
assert isinstance(result, Output)
118+
assert "first_name" in json.loads(result.content)
116119

117120

118121
@pytest.mark.api_call
119122
def test_dottxt_generator_pydantic_call(model):
120123
generator = Generator(model, User)
121124
result = generator("Create a user")
122-
assert "first_name" in json.loads(result)
125+
assert isinstance(result, Output)
126+
assert "first_name" in json.loads(result.content)
123127

124128

125129
@pytest.mark.api_call

tests/models/test_dottxt_type_adapter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from outlines.inputs import Image
1212
from outlines.models.dottxt import DottxtTypeAdapter
13+
from outlines.tools import ToolDef
1314
from outlines.types import cfg, json_schema, regex
1415

1516
if sys.version_info >= (3, 12):
@@ -58,7 +59,7 @@ def test_dottxt_type_adapter_input_text(adapter):
5859

5960

6061
def test_dottxt_type_adapter_input_invalid(adapter, image):
61-
prompt = ["prompt", image]
62+
prompt = ["prompt", Image(image)]
6263
with pytest.raises(TypeError, match="The input type"):
6364
_ = adapter.format_input(prompt)
6465

@@ -135,3 +136,15 @@ def test_dottxt_type_adapter_json_schema_str(adapter, schema):
135136
def test_dottxt_type_adapter_json_schema_dict(adapter, schema):
136137
result = adapter.format_output_type(json_schema(schema))
137138
assert result == json.dumps(schema)
139+
140+
141+
def test_dottxt_type_adapter_tools(adapter):
142+
with pytest.raises(
143+
NotImplementedError,
144+
match="Dottxt does not support tools."
145+
):
146+
adapter.format_tools(
147+
[ToolDef(name="test", description="test", parameters={})]
148+
)
149+
150+
adapter.format_tools(None)

0 commit comments

Comments
 (0)