Skip to content

Commit c52323d

Browse files
committed
feat: specialize isinstance for tuple of primitive types
1 parent de8b296 commit c52323d

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed

mypyc/irbuild/specialize.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RefExpr,
3131
StrExpr,
3232
SuperExpr,
33+
SymbolNode,
3334
TupleExpr,
3435
Var,
3536
)
@@ -653,13 +654,51 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
653654
obj = builder.accept(obj_expr, can_borrow=can_borrow)
654655
return builder.builder.isinstance_helper(obj, irs, expr.line)
655656

657+
if isinstance(type_expr, TupleExpr) and type_expr.items:
658+
nodes: list[SymbolNode | None] = []
659+
for item in type_expr.items:
660+
if not isinstance(item, RefExpr):
661+
return None
662+
if item.node is None:
663+
return None
664+
if item.node.fullname not in nodes:
665+
nodes.append(item.node.fullname)
666+
667+
descs = [isinstance_primitives.get(fullname) for fullname in nodes]
668+
669+
obj = builder.accept(expr.args[0])
670+
671+
retval = Register(bool_rprimitive)
672+
pass_block = BasicBlock()
673+
fail_block = BasicBlock()
674+
exit_block = BasicBlock()
675+
676+
# Chain the checks: if any succeed, jump to pass_block; else, continue
677+
for i, desc in enumerate(descs):
678+
is_last = (i == len(descs) - 1)
679+
next_block = fail_block if is_last else BasicBlock()
680+
builder.add_bool_branch(builder.primitive_op(desc, [obj], expr.line), pass_block, next_block)
681+
if not is_last:
682+
builder.activate_block(next_block)
683+
684+
# If any check passed
685+
builder.activate_block(pass_block)
686+
builder.assign(retval, builder.true(), expr.line)
687+
builder.goto(exit_block)
688+
689+
# If all checks failed
690+
builder.activate_block(fail_block)
691+
builder.assign(retval, builder.false(), expr.line)
692+
builder.goto(exit_block)
693+
694+
# Return the result
695+
builder.activate_block(exit_block)
696+
return retval
697+
656698
if isinstance(type_expr, RefExpr):
657-
node = type_expr.node
658-
if node:
659-
desc = isinstance_primitives.get(node.fullname)
660-
if desc:
661-
obj = builder.accept(obj_expr)
662-
return builder.primitive_op(desc, [obj], expr.line)
699+
if node := type_expr.node:
700+
if desc := isinstance_primitives.get(node.fullname):
701+
return builder.primitive_op(desc, [builder.accept(obj_expr)], expr.line)
663702

664703
elif isinstance(type_expr, TupleExpr):
665704
node_names: list[str] = []

0 commit comments

Comments
 (0)