Skip to content

Commit bfe4c45

Browse files
authored
Fix support for applying tool_method to pydantic base model classes. Migrate a bunch of optional import logic to be lazy. Clean up dependencies and typing for optional packages. (#237)
1 parent e98af69 commit bfe4c45

File tree

20 files changed

+481
-288
lines changed

20 files changed

+481
-288
lines changed

.hooks/generate_docs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import typing as t
44
from pathlib import Path
55

6-
from markdown import Markdown # type: ignore[import-untyped]
7-
from markdownify import MarkdownConverter # type: ignore[import-untyped]
6+
from markdown import Markdown # type: ignore [import-untyped]
7+
from markdownify import MarkdownConverter # type: ignore [import-untyped]
88
from markupsafe import Markup
99
from mkdocstrings_handlers.python._internal.config import PythonConfig
1010
from mkdocstrings_handlers.python._internal.handler import (
@@ -14,7 +14,7 @@
1414
# ruff: noqa: T201
1515

1616

17-
class CustomMarkdownConverter(MarkdownConverter): # type: ignore[misc]
17+
class CustomMarkdownConverter(MarkdownConverter): # type: ignore [misc]
1818
# Strip extra whitespace from code blocks
1919
def convert_pre(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any:
2020
return super().convert_pre(el, text.strip(), parent_tags)

docs/api/data.mdx

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ chat\_id column.
5858

5959
<Accordion title="Source code in rigging/data.py" icon="code">
6060
```python
61-
def chats_to_df(chats: Chat | t.Sequence[Chat]) -> pd.DataFrame:
61+
def chats_to_df(chats: Chat | t.Sequence[Chat]) -> "pd.DataFrame":
6262
"""
6363
Convert a Chat or list of Chat objects into a pandas DataFrame.
6464
@@ -73,6 +73,13 @@ def chats_to_df(chats: Chat | t.Sequence[Chat]) -> pd.DataFrame:
7373
A pandas DataFrame containing the chat data.
7474
7575
"""
76+
try:
77+
import pandas as pd
78+
except ImportError as e:
79+
raise ImportError(
80+
"Pandas is not available. Please install `pandas` or use `rigging[data]`.",
81+
) from e
82+
7683
chats = [chats] if isinstance(chats, Chat) else chats
7784

7885
flattened = flatten_chats(chats)
@@ -176,10 +183,10 @@ async def chats_to_elastic(
176183
The indexed count from the bulk operation
177184
"""
178185
try:
179-
import elasticsearch.helpers
186+
import elasticsearch.helpers # type: ignore [import-not-found, unused-ignore]
180187
except ImportError as e:
181188
raise ImportError(
182-
"Elasticsearch is not available. Please install `elasticsearch` or use `rigging[extra]`.",
189+
"Elasticsearch is not available. Please install `elasticsearch` or use `rigging[data]`.",
183190
) from e
184191

185192
es_data = chats_to_elastic_data(chats, index, op_type=op_type)
@@ -190,7 +197,7 @@ async def chats_to_elastic(
190197
await client.indices.put_mapping(index=index, properties=ElasticMapping["properties"])
191198

192199
results = await elasticsearch.helpers.async_bulk(client, es_data, **kwargs)
193-
return results[0] # Return modified count
200+
return results[0] # type: ignore [no-any-return, unused-ignore]
194201
```
195202

196203

@@ -286,7 +293,7 @@ generated by the `chats_to_df` function.
286293

287294
<Accordion title="Source code in rigging/data.py" icon="code">
288295
```python
289-
def df_to_chats(df: pd.DataFrame) -> list[Chat]:
296+
def df_to_chats(df: "pd.DataFrame") -> list[Chat]:
290297
"""
291298
Convert a pandas DataFrame into a list of Chat objects.
292299
@@ -301,6 +308,7 @@ def df_to_chats(df: pd.DataFrame) -> list[Chat]:
301308
A list of Chat objects.
302309
303310
"""
311+
304312
chats = []
305313
for chat_id, chat_group in df.groupby("chat_id"):
306314
chat_data = chat_group.iloc[0]
@@ -564,7 +572,7 @@ Determine if an S3 bucket exists.
564572

565573
<Accordion title="Source code in rigging/data.py" icon="code">
566574
```python
567-
async def s3_bucket_exists(client: S3Client, bucket: str) -> bool:
575+
async def s3_bucket_exists(client: "S3Client", bucket: str) -> bool:
568576
"""
569577
Determine if an S3 bucket exists.
570578
@@ -618,7 +626,7 @@ Determine if an S3 object exists.
618626

619627
<Accordion title="Source code in rigging/data.py" icon="code">
620628
```python
621-
async def s3_object_exists(client: S3Client, bucket: str, key: str) -> bool:
629+
async def s3_object_exists(client: "S3Client", bucket: str, key: str) -> bool:
622630
"""
623631
Determine if an S3 object exists.
624632

docs/api/logging.mdx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ configure\_logging
3232

3333
```python
3434
configure_logging(
35-
log_level: LogLevelLiteral,
35+
log_level: LogLevelLiteral = "info",
3636
log_file: Path | None = None,
3737
log_file_level: LogLevelLiteral = "debug",
3838
) -> None
@@ -43,7 +43,9 @@ Configures common loguru handlers.
4343
**Parameters:**
4444

4545
* **`log_level`**
46-
(`LogLevelLiteral`)
46+
(`LogLevelLiteral`, default:
47+
`'info'`
48+
)
4749
–The desired log level.
4850
* **`log_file`**
4951
(`Path | None`, default:
@@ -60,7 +62,7 @@ Configures common loguru handlers.
6062
<Accordion title="Source code in rigging/logging.py" icon="code">
6163
```python
6264
def configure_logging(
63-
log_level: LogLevelLiteral,
65+
log_level: LogLevelLiteral = "info",
6466
log_file: pathlib.Path | None = None,
6567
log_file_level: LogLevelLiteral = "debug",
6668
) -> None:

docs/api/tools.mdx

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ Base class for representing a tool to a generator.
4141
### catch
4242

4343
```python
44-
catch: bool | set[type[Exception]] = {
45-
JSONDecodeError,
46-
ValidationError,
47-
}
44+
catch: bool | set[type[Exception]] = set(
45+
DEFAULT_CATCH_EXCEPTIONS
46+
)
4847
```
4948

5049
Whether to catch exceptions and return them as messages.
@@ -207,6 +206,8 @@ async def handle_tool_call( # noqa: PLR0912
207206
kwargs = json.loads(tool_call.function.arguments)
208207
if self._type_adapter is not None:
209208
kwargs = self._type_adapter.validate_python(kwargs)
209+
kwargs = kwargs or {}
210+
210211
dn.log_inputs(**kwargs)
211212

212213
# Call the function
@@ -367,7 +368,51 @@ def with_(
367368
ToolMethod
368369
----------
369370

370-
A Tool wrapping a class method.
371+
```python
372+
ToolMethod(
373+
fget: Callable[..., Any],
374+
name: str,
375+
description: str,
376+
parameters_schema: dict[str, Any],
377+
catch: bool | Iterable[type[Exception]] | None,
378+
truncate: int | None,
379+
signature: Signature,
380+
type_adapter: TypeAdapter[Any],
381+
)
382+
```
383+
384+
A descriptor that acts as a factory for creating bound Tool instances.
385+
386+
It inherits from `property` to be ignored by pydantic's `ModelMetaclass`
387+
during field inspection. This prevents validation errors which would
388+
otherwise treat the descriptor as a field and stop tool\_method decorators
389+
from being applied in BaseModel classes.
390+
391+
<Accordion title="Source code in rigging/tools/base.py" icon="code">
392+
```python
393+
def __init__(
394+
self,
395+
fget: t.Callable[..., t.Any],
396+
name: str,
397+
description: str,
398+
parameters_schema: dict[str, t.Any],
399+
catch: bool | t.Iterable[type[Exception]] | None,
400+
truncate: int | None,
401+
signature: inspect.Signature,
402+
type_adapter: TypeAdapter[t.Any],
403+
):
404+
super().__init__(fget)
405+
self.tool_name = name
406+
self.tool_description = description
407+
self.tool_parameters_schema = parameters_schema
408+
self.tool_catch = catch
409+
self.tool_truncate = truncate
410+
self._tool_signature = signature
411+
self._tool_type_adapter = type_adapter
412+
```
413+
414+
415+
</Accordion>
371416

372417
tool
373418
----
@@ -621,41 +666,38 @@ def tool_method(
621666
~~~
622667
"""
623668

624-
def make_tool(func: t.Callable[..., t.Any]) -> ToolMethod[P, R]:
625-
# TODO: Improve consistency of detection here before enabling this warning
626-
# if not _is_unbound_method(func):
627-
# warnings.warn(
628-
# "Passing a regular function to @tool_method improperly handles the 'self' argument, use @tool instead.",
629-
# SyntaxWarning,
630-
# stacklevel=3,
631-
# )
669+
def make_tool(f: t.Callable[t.Concatenate[t.Any, P], R]) -> ToolMethod[P, R]:
670+
# This logic is specialized from `Tool.from_callable` to correctly
671+
# handle the `self` parameter in method signatures.
632672

633-
# Strip the `self` argument from the function signature so
634-
# our schema generation doesn't include it under the hood.
673+
signature = inspect.signature(f)
674+
params_without_self = [p for p_name, p in signature.parameters.items() if p_name != "self"]
675+
schema_signature = signature.replace(parameters=params_without_self)
635676

636-
@functools.wraps(func)
637-
def wrapper(self: t.Any, *args: P.args, **kwargs: P.kwargs) -> R:
638-
return func(self, *args, **kwargs) # type: ignore [no-any-return]
677+
@functools.wraps(f)
678+
def empty_func(*_: t.Any, **kwargs: t.Any) -> t.Any:
679+
return kwargs
639680

640-
wrapper.__signature__ = inspect.signature(func).replace( # type: ignore [attr-defined]
641-
parameters=tuple(
642-
param
643-
for param in inspect.signature(func).parameters.values()
644-
if param.name != "self"
645-
),
646-
)
681+
empty_func.__signature__ = schema_signature # type: ignore [attr-defined]
682+
type_adapter: TypeAdapter[t.Any] = TypeAdapter(empty_func)
683+
schema = deref_json(type_adapter.json_schema(), is_json_schema=True)
647684

648-
return ToolMethod.from_callable(
649-
wrapper, # type: ignore [arg-type]
650-
name=name,
651-
description=description,
685+
tool_name = name or f.__name__
686+
tool_description = inspect.cleandoc(description or f.__doc__ or "")
687+
688+
return ToolMethod(
689+
fget=f,
690+
name=tool_name,
691+
description=tool_description,
692+
parameters_schema=schema,
652693
catch=catch,
653694
truncate=truncate,
695+
signature=schema_signature,
696+
type_adapter=type_adapter,
654697
)
655698

656699
if func is not None:
657700
return make_tool(func)
658-
659701
return make_tool
660702
```
661703

@@ -677,7 +719,7 @@ A client for communicating with MCP servers.
677719

678720
<Accordion title="Source code in rigging/tools/mcp.py" icon="code">
679721
```python
680-
def __init__(self, transport: Transport, connection: StdioConnection | SSEConnection) -> None:
722+
def __init__(self, transport: Transport, connection: "StdioConnection | SSEConnection") -> None:
681723
self.transport = transport
682724
self.connection = connection
683725
self.tools = []
@@ -793,6 +835,9 @@ def as_mcp(
793835
)
794836
~~~
795837
"""
838+
from mcp.server.fastmcp import FastMCP
839+
from mcp.server.fastmcp.tools import Tool as FastMCPTool
840+
796841
rigging_tools: list[Tool[..., t.Any]] = []
797842
for tool in flatten_list(list(tools)):
798843
interior_tools = [

docs/api/watchers.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ Create a watcher to write each chat to an Amazon S3 bucket.
343343
<Accordion title="Source code in rigging/watchers.py" icon="code">
344344
```python
345345
def write_chats_to_s3(
346-
client: S3Client,
346+
client: "S3Client",
347347
bucket: str,
348348
key: str,
349349
*,

0 commit comments

Comments
 (0)