Skip to content

Commit 66bac1f

Browse files
committed
Fix default converters
1 parent 32d8c99 commit 66bac1f

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

twitchio/ext/commands/core.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __eq__(self, __value: object) -> bool:
6363
EMPTY = EmptyArgumentSentinel()
6464

6565

66-
def _boolconverter(param: str):
66+
def _boolconverter(_, param: str):
6767
param = param.lower()
6868
if param in {"yes", "y", "1", "true", "on"}:
6969
return True
@@ -147,10 +147,9 @@ def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Cont
147147

148148
async def _resolve(context: Context, arg: str) -> Any:
149149
t = EMPTY
150-
last = None
151150

152151
for original in args:
153-
underlying = self._resolve_converter(name, original)
152+
underlying = self._resolve_converter(name, original, context)
154153

155154
try:
156155
t: Any = underlying(context, arg)
@@ -159,7 +158,6 @@ async def _resolve(context: Context, arg: str) -> Any:
159158

160159
break
161160
except Exception as l:
162-
last = l
163161
t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back
164162
continue
165163

@@ -170,8 +168,8 @@ async def _resolve(context: Context, arg: str) -> Any:
170168

171169
return _resolve
172170

173-
def resolve_optional_callback(self, name: str, converter: Any) -> Callable[[Context, str], Any]:
174-
underlying = self._resolve_converter(name, converter.__args__[0])
171+
def resolve_optional_callback(self, name: str, converter: Any, context: Context) -> Callable[[Context, str], Any]:
172+
underlying = self._resolve_converter(name, converter.__args__[0], context)
175173

176174
async def _resolve(context: Context, arg: str) -> Any:
177175
try:
@@ -186,7 +184,7 @@ async def _resolve(context: Context, arg: str) -> Any:
186184

187185
return _resolve
188186

189-
def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]:
187+
def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type], ctx: Context) -> Callable[..., Any]:
190188
if (
191189
isinstance(converter, type)
192190
and converter.__module__.startswith("twitchio")
@@ -198,26 +196,27 @@ def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, ty
198196
converter = self._convert_builtin_type(name, bool, _boolconverter)
199197

200198
elif converter in (str, int):
201-
converter = self._convert_builtin_type(name, converter, converter) # type: ignore
199+
original: type[str | int] = converter # type: ignore
200+
converter = self._convert_builtin_type(name, original, lambda _, arg: original(arg))
202201

203202
elif self._is_optional_argument(converter):
204-
return self.resolve_optional_callback(name, converter)
203+
return self.resolve_optional_callback(name, converter, ctx)
205204

206205
elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union:
207206
return self.resolve_union_callback(name, converter) # type: ignore
208207

209208
elif hasattr(converter, "__metadata__"): # Annotated
210209
annotated = converter.__metadata__ # type: ignore
211-
return self._resolve_converter(name, annotated[0])
210+
return self._resolve_converter(name, annotated[0], ctx)
212211

213212
return converter # type: ignore
214213

215214
def _convert_builtin_type(
216-
self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]]
215+
self, arg_name: str, original: type, converter: Union[Callable[[Context, str], Any], Callable[[Context, str], Awaitable[Any]]]
217216
) -> Callable[[Context, str], Awaitable[Any]]:
218-
async def resolve(_, arg: str) -> Any:
217+
async def resolve(ctx, arg: str) -> Any:
219218
try:
220-
t = converter(arg)
219+
t = converter(ctx, arg)
221220

222221
if inspect.iscoroutine(t):
223222
t = await t
@@ -242,7 +241,7 @@ async def _convert_types(self, context: Context, param: inspect.Parameter, parse
242241
else:
243242
converter = type(param.default)
244243

245-
true_converter = self._resolve_converter(param.name, converter)
244+
true_converter = self._resolve_converter(param.name, converter, context)
246245

247246
try:
248247
argument = true_converter(context, parsed)

0 commit comments

Comments
 (0)