Skip to content

Commit fdd25a4

Browse files
authored
Update trivial inference for Python 3.14 (#37248)
* Update trivial inference for Python 3.14 * correct comment * Address review coments * avoid none case being incorrect * fix docstring
1 parent 3783f58 commit fdd25a4

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _get_args(typ):
9595
A tuple of args.
9696
"""
9797
try:
98-
if typ.__args__ is None:
98+
if typ.__args__ is None or not isinstance(typ.__args__, tuple):
9999
return ()
100100
return typ.__args__
101101
except AttributeError:

sdks/python/apache_beam/typehints/native_type_compatibility_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def test_forward_reference(self):
337337
self.assertEqual(typehints.Any, convert_to_beam_type('int'))
338338
self.assertEqual(typehints.Any, convert_to_beam_type('typing.List[int]'))
339339
self.assertEqual(
340-
typehints.List[typehints.Any], convert_to_beam_type(typing.List['int']))
340+
typehints.List[typehints.Any], convert_to_beam_type(list['int']))
341341

342342
def test_convert_nested_to_beam_type(self):
343343
self.assertEqual(typehints.List[typing.Any], typehints.List[typehints.Any])

sdks/python/apache_beam/typehints/opcodes.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@
6363
else:
6464
_div_binop_args = frozenset()
6565

66+
if sys.version_info >= (3, 14):
67+
_NB_SUBSCR_OPCODE = [op[0] for op in opcode._nb_ops].index('NB_SUBSCR')
68+
else:
69+
_NB_SUBSCR_OPCODE = -1
70+
6671

6772
def pop_one(state, unused_arg):
6873
del state.stack[-1:]
@@ -151,6 +156,9 @@ def get_iter(state, unused_arg):
151156

152157
def symmetric_binary_op(state, arg, is_true_div=None):
153158
# TODO(robertwb): This may not be entirely correct...
159+
# BINARY_SUBSCR was rolled into BINARY_OP in 3.14.
160+
if arg == _NB_SUBSCR_OPCODE:
161+
return binary_subscr(state, arg)
154162
b, a = Const.unwrap(state.stack.pop()), Const.unwrap(state.stack.pop())
155163
if a == b:
156164
if a is int and b is int and (arg in _div_binop_args or is_true_div):
@@ -206,7 +214,10 @@ def binary_subscr(state, unused_arg):
206214
out = base._constraint_for_index(index.value)
207215
except IndexError:
208216
out = element_type(base)
209-
elif index == slice and isinstance(base, typehints.IndexableTypeConstraint):
217+
elif (index == slice or getattr(index, 'type', None) == slice) and isinstance(
218+
base, typehints.IndexableTypeConstraint):
219+
# The slice is treated as a const in 3.14, using this instead of
220+
# BINARY_SLICE
210221
out = base
211222
else:
212223
out = element_type(base)
@@ -483,20 +494,29 @@ def load_global(state, arg):
483494
state.stack.append(state.get_global(arg))
484495

485496

497+
def load_small_int(state, arg):
498+
state.stack.append(Const(arg))
499+
500+
486501
store_map = pop_two
487502

488503

489504
def load_fast(state, arg):
490505
state.stack.append(state.vars[arg])
491506

492507

508+
load_fast_borrow = load_fast
509+
510+
493511
def load_fast_load_fast(state, arg):
494512
arg1 = arg >> 4
495513
arg2 = arg & 15
496514
state.stack.append(state.vars[arg1])
497515
state.stack.append(state.vars[arg2])
498516

499517

518+
load_fast_borrow_load_fast_borrow = load_fast_load_fast
519+
500520
load_fast_check = load_fast
501521

502522

@@ -605,6 +625,8 @@ def set_function_attribute(state, arg):
605625
for t in state.stack[attr].tuple_types)
606626
new_func = types.FunctionType(
607627
func.code, func.globals, name=func.name, closure=closure)
628+
if arg & 0x10:
629+
new_func.__annotate__ = attr
608630
state.stack.append(Const(new_func))
609631

610632

sdks/python/apache_beam/typehints/trivial_inference.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
396396

397397
jump_multiplier = 2
398398

399+
# Python 3.14+ push nulls are used to signal kwargs for CALL_FUNCTION_EX
400+
# so there must be a little extra bookkeeping even if we don't care about
401+
# the nulls themselves.
402+
last_op_push_null = 0
403+
399404
last_pc = -1
400405
last_real_opname = opname = None
401406
while pc < end: # pylint: disable=too-many-nested-blocks
@@ -441,7 +446,8 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
441446
elif op in dis.haslocal:
442447
# Args to double-fast opcodes are bit manipulated, correct the arg
443448
# for printing + avoid the out-of-index
444-
if dis.opname[op] == 'LOAD_FAST_LOAD_FAST':
449+
if dis.opname[op] == 'LOAD_FAST_LOAD_FAST' or dis.opname[
450+
op] == "LOAD_FAST_BORROW_LOAD_FAST_BORROW":
445451
print(
446452
'(' + co.co_varnames[arg >> 4] + ', ' +
447453
co.co_varnames[arg & 15] + ')',
@@ -450,6 +456,8 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
450456
print('(' + co.co_varnames[arg & 15] + ')', end=' ')
451457
elif dis.opname[op] == 'STORE_FAST_STORE_FAST':
452458
pass
459+
elif dis.opname[op] == 'LOAD_DEREF':
460+
pass
453461
else:
454462
print('(' + co.co_varnames[arg] + ')', end=' ')
455463
elif op in dis.hascompare:
@@ -512,6 +520,12 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
512520
# stack[-has_kwargs]: Map of keyword args.
513521
# stack[-1 - has_kwargs]: Iterable of positional args.
514522
# stack[-2 - has_kwargs]: Function to call.
523+
if arg is None:
524+
# CALL_FUNCTION_EX does not take an arg in 3.14, instead the
525+
# signaling for kwargs is done via a PUSH_NULL instruction
526+
# right before CALL_FUNCTION_EX. A PUSH_NULL indicates that
527+
# there are no kwargs.
528+
arg = ~last_op_push_null
515529
has_kwargs: int = arg & 1
516530
pop_count = has_kwargs + 2
517531
if has_kwargs:
@@ -680,6 +694,9 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
680694
jmp_state = state.copy()
681695
jmp_state.stack.pop()
682696
state.stack.append(element_type(state.stack[-1]))
697+
elif opname == 'POP_ITER':
698+
# Introduced in 3.14.
699+
state.stack.pop()
683700
elif opname == 'COPY_FREE_VARS':
684701
# Helps with calling closures, but since we aren't executing
685702
# them we can treat this as a no-op
@@ -694,6 +711,10 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
694711
# We're treating this as a no-op to avoid having to check
695712
# for extra None values on the stack when we extract return
696713
# values
714+
last_op_push_null = 1
715+
pass
716+
elif opname == 'NOT_TAKEN':
717+
# NOT_TAKEN is a no-op introduced in 3.14.
697718
pass
698719
elif opname == 'PRECALL':
699720
# PRECALL is a no-op.
@@ -727,6 +748,10 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
727748
else:
728749
raise TypeInferenceError('unable to handle %s' % opname)
729750

751+
# Clear check for previous push_null.
752+
if opname != 'PUSH_NULL' and last_op_push_null == 1:
753+
last_op_push_null = 0
754+
730755
if jmp is not None:
731756
# TODO(robertwb): Is this guaranteed to converge?
732757
new_state = states[jmp] | jmp_state

0 commit comments

Comments
 (0)