Skip to content

Commit da1fed8

Browse files
committed
🐛 fix: fix pep695 union import
1 parent b414a84 commit da1fed8

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

pyfuture/codemod/pep622/match.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,21 @@ def test5():
252252
test_value = 123
253253
if test_value == y:
254254
print("other")
255+
256+
>>> module = cst.parse_module(\"""
257+
... def test6():
258+
... for i in range(2):
259+
... match i:
260+
... case 0:
261+
... yield 0
262+
... case 1:
263+
... yield 1
264+
... \""")
265+
>>> new_module = transformer.transform_module(module)
255266
"""
256267

268+
# TODO(zrr1999): Need to support nested
269+
257270
METADATA_DEPENDENCIES = (ScopeProvider,)
258271

259272
def __init__(self, context: CodemodContext) -> None:

pyfuture/codemod/pep695/type_parameters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,14 @@ def test(x: __test_T) -> __test_T:
5252
... \""")
5353
>>> new_module = transformer.transform_module(module)
5454
>>> print(new_module.code)
55-
from typing import TypeVar
55+
from typing import TypeVar, Union
5656
def __wrapper_func_test():
5757
__test_T = TypeVar("__test_T", bound = Union[int, str])
5858
def test(x: __test_T) -> __test_T:
5959
return x
6060
return test
6161
test = __wrapper_func_test()
62+
>>> transformer = TransformTypeParametersCommand(CodemodContext())
6263
>>> module = cst.parse_module(\"""
6364
... class Test[T: int]:
6465
... def test[P: str](self, x: T, y: P) -> tuple[T, P]:
@@ -94,8 +95,7 @@ def remove_type_parameters[T: FunctionDef | ClassDef](
9495
new_name = type_param.param.name.with_changes(value=f"{prefix}{type_param.param.name.value}{suffix}")
9596

9697
AddImportsVisitor.add_needed_import(self.context, "typing", type_param.param.__class__.__name__)
97-
AddImportsVisitor.add_needed_import(self.context, "typing", "Union")
98-
statements.append(gen_type_param(type_param.param, new_name))
98+
statements.append(gen_type_param(type_param.param, new_name, self.context))
9999
slices.append(
100100
SubscriptElement(
101101
slice=Index(value=new_name),

pyfuture/codemod/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from enum import Enum
55

66
import libcst as cst
7-
from libcst.codemod import Codemod
7+
from libcst.codemod import Codemod, CodemodContext
8+
from libcst.codemod.visitors import AddImportsVisitor
89

910

1011
class RuleSet(Enum):
@@ -102,7 +103,9 @@ def _split_bit_or(_op: cst.BinaryOperation) -> list[cst.BaseExpression]:
102103

103104

104105
def gen_type_param(
105-
type_param: cst.TypeVar | cst.TypeVarTuple | cst.ParamSpec, type_name: cst.Name | None = None
106+
type_param: cst.TypeVar | cst.TypeVarTuple | cst.ParamSpec,
107+
type_name: cst.Name | None = None,
108+
context: CodemodContext | None = None,
106109
) -> cst.SimpleStatementLine:
107110
"""
108111
To generate type parameter definition statement.
@@ -128,6 +131,12 @@ def gen_type_param(
128131
if bound is not None:
129132
if isinstance(bound, cst.BinaryOperation):
130133
bound = transform_bit_or(bound) or bound
134+
if (
135+
isinstance(bound, cst.Subscript)
136+
and isinstance(bound_value := bound.value, cst.Name)
137+
and context is not None
138+
):
139+
AddImportsVisitor.add_needed_import(context, "typing", bound_value.value)
131140
args.append(cst.Arg(bound, keyword=cst.Name("bound")))
132141
return cst.SimpleStatementLine(
133142
[

0 commit comments

Comments
 (0)