15
15
16
16
import asyncio
17
17
import types
18
- from collections import defaultdict
19
18
from inspect import Parameter , Signature
20
19
from typing import (
21
20
Any ,
22
21
Callable ,
23
- DefaultDict ,
24
22
Iterable ,
25
23
Mapping ,
26
24
Optional ,
29
27
)
30
28
31
29
from aiohttp import ClientSession
32
- from pytest import Session
33
-
34
30
35
31
class ToolboxTool :
36
32
"""
@@ -52,6 +48,7 @@ def __init__(
52
48
name : str ,
53
49
desc : str ,
54
50
params : Sequence [Parameter ],
51
+ params_metadata : Mapping [str , tuple [str , str ]],
55
52
required_authn_params : Mapping [str , list [str ]],
56
53
auth_service_token_getters : Mapping [str , Callable [[], str ]],
57
54
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
@@ -83,10 +80,11 @@ def __init__(
83
80
84
81
self .__desc = desc
85
82
self .__params = params
83
+ self .__params_metadata = params_metadata
86
84
87
85
# the following properties are set to help anyone that might inspect it determine usage
88
86
self .__name__ = name
89
- self .__doc__ = desc
87
+ self .__doc__ = self . _schema_to_docstring ( desc , params , params_metadata )
90
88
self .__signature__ = Signature (parameters = params , return_annotation = str )
91
89
self .__annotations__ = {p .name : p .annotation for p in params }
92
90
# TODO: self.__qualname__ ??
@@ -98,6 +96,21 @@ def __init__(
98
96
# map of parameter name to value (or callable that produces that value)
99
97
self .__bound_parameters = bound_params
100
98
99
+ @staticmethod
100
+ def _schema_to_docstring (
101
+ tool_description : str ,
102
+ params : Sequence [Parameter ],
103
+ params_metadata : Mapping [str , tuple [str , str ]],
104
+ ):
105
+ docstring = tool_description
106
+ if not params :
107
+ return docstring
108
+ docstring += "\n \n Args:"
109
+ for p in params :
110
+ param_metadata = params_metadata [p .name ]
111
+ docstring += f"\n { p .name } ({ param_metadata [0 ]} ): { param_metadata [1 ]} "
112
+ return docstring
113
+
101
114
def __copy (
102
115
self ,
103
116
session : Optional [ClientSession ] = None ,
@@ -134,6 +147,7 @@ def __copy(
134
147
name = check (name , self .__name__ ),
135
148
desc = check (desc , self .__desc ),
136
149
params = check (params , self .__params ),
150
+ params_metadata = self .__params_metadata ,
137
151
required_authn_params = check (
138
152
required_authn_params , self .__required_authn_params
139
153
),
0 commit comments