Skip to content

Commit 93b4a3a

Browse files
committed
[mypy] contextmanager
1 parent 665f5ad commit 93b4a3a

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

mypyc/irbuild/statement.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@
111111
reraise_exception_op,
112112
restore_exc_info_op,
113113
)
114-
from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op
114+
from mypyc.primitives.generic_ops import iter_op, next_op, next_raw_op, py_delattr_op
115115
from mypyc.primitives.misc_ops import (
116116
check_stop_op,
117117
coro_op,
@@ -887,6 +887,19 @@ def transform_with(
887887
is_async: bool,
888888
line: int,
889889
) -> None:
890+
891+
if (
892+
not is_async
893+
and isinstance(expr, mypy.nodes.CallExpr)
894+
and isinstance(expr.callee, mypy.nodes.RefExpr)
895+
and isinstance(dec := expr.callee.node, mypy.nodes.Decorator)
896+
and len(dec.decorators) == 1
897+
and isinstance(dec1 := dec.decorators[0], mypy.nodes.RefExpr)
898+
and dec1.node
899+
and dec1.node.fullname == "contextlib.contextmanager"
900+
):
901+
return _transform_with_contextmanager(builder, expr, target, body, line)
902+
890903
# This is basically a straight transcription of the Python code in PEP 343.
891904
# I don't actually understand why a bunch of it is the way it is.
892905
# We could probably optimize the case where the manager is compiled by us,
@@ -964,6 +977,102 @@ def finally_body() -> None:
964977
)
965978

966979

980+
def _transform_with_contextmanager(
981+
builder: IRBuilder,
982+
expr: mypy.nodes.CallExpr,
983+
target: Lvalue | None,
984+
with_body: GenFunc,
985+
line: int,
986+
) -> None:
987+
assert isinstance(expr.callee, mypy.nodes.RefExpr)
988+
dec = expr.callee.node
989+
assert isinstance(dec, mypy.nodes.Decorator)
990+
991+
# mgrv = ctx.__wrapped__(*args, **kwargs)
992+
wrapped_call = mypy.nodes.CallExpr(
993+
mypy.nodes.MemberExpr(expr.callee, "__wrapped__"), expr.args, expr.arg_kinds, expr.arg_names
994+
)
995+
gen = builder.accept(wrapped_call)
996+
997+
# try:
998+
# target = next(gen)
999+
# except StopIteration:
1000+
# raise RuntimeError("generator didn't yield") from None
1001+
mgr_target = builder.call_c(next_raw_op, [gen], line)
1002+
1003+
runtime_block, main_block = BasicBlock(), BasicBlock()
1004+
builder.add(Branch(mgr_target, runtime_block, main_block, Branch.IS_ERROR))
1005+
1006+
builder.activate_block(runtime_block)
1007+
builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "generator didn't yield", line))
1008+
builder.add(Unreachable())
1009+
1010+
builder.activate_block(main_block)
1011+
1012+
# try:
1013+
# {body}
1014+
1015+
def try_body() -> None:
1016+
if target:
1017+
builder.assign(builder.get_assignment_target(target), mgr_target, line)
1018+
with_body()
1019+
1020+
# except Exception as e:
1021+
# exc = True
1022+
# try:
1023+
# gen.throw(e)
1024+
# except StopIteration as e2:
1025+
# if e2 is not e:
1026+
# raise
1027+
# return
1028+
# except RuntimeError:
1029+
# # TODO: some other stuff
1030+
# ...
1031+
# except BaseException:
1032+
# # TODO: some other stuff
1033+
# ...
1034+
1035+
def except_body() -> None:
1036+
builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "TODO", line))
1037+
builder.add(Unreachable())
1038+
1039+
# TODO: actually do the exceptions
1040+
handlers = [(None, None, except_body)]
1041+
1042+
# else:
1043+
# try:
1044+
# next(gen)
1045+
# except StopIteration:
1046+
# pass
1047+
# else:
1048+
# try:
1049+
# raise RuntimeError("generator didn't stop")
1050+
# finally:
1051+
# gen.close()
1052+
1053+
def else_body() -> None:
1054+
value = builder.call_c(next_raw_op, [builder.read(gen)], line)
1055+
stop_block, close_block = BasicBlock(), BasicBlock()
1056+
builder.add(Branch(value, stop_block, close_block, Branch.IS_ERROR))
1057+
1058+
builder.activate_block(close_block)
1059+
# TODO: this isn't exactly the right order
1060+
builder.py_call(builder.py_get_attr(gen, "close", line), [], line)
1061+
builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "generator didn't stop", line))
1062+
builder.add(Unreachable())
1063+
1064+
builder.activate_block(stop_block)
1065+
builder.call_c(error_catch_op, [], -1)
1066+
1067+
transform_try_except(
1068+
builder,
1069+
try_body,
1070+
handlers,
1071+
else_body,
1072+
line,
1073+
)
1074+
1075+
9671076
def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None:
9681077
# Generate separate logic for each expr in it, left to right
9691078
def generate(i: int) -> None:

0 commit comments

Comments
 (0)