Skip to content

Commit bd2f0ed

Browse files
authored
initial Python 3.14 support (#2679)
1 parent f946af6 commit bd2f0ed

File tree

4 files changed

+141
-20
lines changed

4 files changed

+141
-20
lines changed

.github/workflows/ci-testing.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ jobs:
3737
- { os: "ubuntu-24.04", python-version: "3.11", requires: "latest" }
3838
- { os: "ubuntu-24.04", python-version: "3.12", requires: "latest" }
3939
- { os: "ubuntu-24.04", python-version: "3.13", requires: "latest" }
40+
- { os: "ubuntu-24.04", python-version: "3.14", requires: "latest" }
4041
exclude:
4142
- { os: "windows-latest", suite: "ops" }
4243
- { os: "windows-latest", suite: "grads" }

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ authors = [
1414
{ name = "Lightning AI", email = "[email protected]" },
1515
]
1616
license = "Apache-2.0"
17-
requires-python = ">=3.10, <3.14"
17+
requires-python = ">=3.10, <3.15"
1818
keywords = ["deep learning", "AI", "compiler"]
1919
classifiers=[
2020
"Environment :: Console",

thunder/core/interpreter.py

Lines changed: 130 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def nexti(self, inst: dis.Instruction):
11431143
if (3, 9) <= sys.version_info < (3, 11):
11441144
if inst.starts_line is not None:
11451145
self.positions = Positions(inst.starts_line, inst.starts_line, 0, 999)
1146-
elif (3, 11) <= sys.version_info < (3, 14):
1146+
elif (3, 11) <= sys.version_info < (3, 15):
11471147
if inst.positions is not None:
11481148
self.positions = inst.positions
11491149
else:
@@ -3263,7 +3263,7 @@ def _async_gen_wrap_handler(inst: dis.Instruction, /, stack: InterpreterStack, *
32633263

32643264

32653265
# https://docs.python.org/3.10/library/dis.html#opcode-BEFORE_ASYNC_WITH
3266-
@register_opcode_handler("BEFORE_ASYNC_WITH")
3266+
@register_opcode_handler("BEFORE_ASYNC_WITH", max_ver=(3, 13))
32673267
def _before_async_with_handler(
32683268
inst: dis.Instruction, /, stack: InterpreterStack, **kwargs
32693269
) -> None | INTERPRETER_SIGNALS:
@@ -3294,7 +3294,7 @@ def _before_async_with_handler(
32943294

32953295

32963296
# https://docs.python.org/3.11/library/dis.html#opcode-BEFORE_WITH
3297-
@register_opcode_handler("BEFORE_WITH", min_ver=(3, 11))
3297+
@register_opcode_handler("BEFORE_WITH", min_ver=(3, 11), max_ver=(3, 13))
32983298
def _before_with_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
32993299
mgr = stack.pop()
33003300

@@ -3349,6 +3349,7 @@ class BINARY_OP(enum.Enum):
33493349
ISUB = 23
33503350
ITRUEDIV = 24
33513351
IXOR = 25
3352+
SUBSCR = 26
33523353

33533354

33543355
def _binary_op(stack: InterpreterStack, op: BINARY_OP, a, b):
@@ -3384,6 +3385,9 @@ def _binary_op(stack: InterpreterStack, op: BINARY_OP, a, b):
33843385
assert type(op) is BINARY_OP
33853386
idx: int = op.value
33863387

3388+
if idx == BINARY_OP.SUBSCR.value:
3389+
return _binary_subscr(stack, a, b)
3390+
33873391
res = Py_NULL()
33883392
binop_name, *_ = ops[idx]
33893393
_, left_method, right_method = ops[idx % BINARY_OP.IADD.value]
@@ -3631,11 +3635,15 @@ def impl(container, start, end):
36313635

36323636

36333637
# https://docs.python.org/3.10/library/dis.html#opcode-BINARY_SUBSCR
3634-
@register_opcode_handler("BINARY_SUBSCR")
3638+
@register_opcode_handler("BINARY_SUBSCR", max_ver=(3, 13))
36353639
def _binary_subscr_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
36363640
tos = stack.pop_wrapped()
36373641
tos1 = stack.pop_wrapped()
36383642

3643+
return _binary_subscr(stack, tos1, tos)
3644+
3645+
3646+
def _binary_subscr(stack, tos1, tos):
36393647
def class_getitem_impl(cls, index):
36403648
return cls.__class_getitem__(index)
36413649

@@ -3654,7 +3662,7 @@ def getitem_impl(obj, index):
36543662

36553663

36563664
# https://docs.python.org/3.10/library/dis.html#opcode-BUILD_CONST_KEY_MAP
3657-
@register_opcode_handler("BUILD_CONST_KEY_MAP")
3665+
@register_opcode_handler("BUILD_CONST_KEY_MAP", max_ver=(3, 13))
36583666
def _build_const_key_map_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
36593667
assert type(inst.arg) is int
36603668
count: int = inst.arg
@@ -3832,8 +3840,15 @@ def _call_function_handler(inst: dis.Instruction, /, stack: InterpreterStack, **
38323840
def _call_function_ex_handler(
38333841
inst: dis.Instruction, /, stack: InterpreterStack, **kwargs
38343842
) -> None | INTERPRETER_SIGNALS:
3835-
assert type(inst.arg) is int
3836-
kwargs = stack.pop_wrapped() if inst.arg & 0x01 else {}
3843+
if sys.version_info < (3, 14):
3844+
inst_arg = inst.arg
3845+
else:
3846+
inst_arg = 1
3847+
assert type(inst_arg) is int
3848+
kwargs = stack.pop_wrapped() if inst_arg & 0x01 else {}
3849+
if sys.version_info >= (3, 14) and wrapped_isinstance(kwargs, Py_NULL):
3850+
kwargs = {}
3851+
38373852
assert wrapped_isinstance(kwargs, Mapping)
38383853
args = stack.pop_wrapped()
38393854
assert wrapped_isinstance(args, Iterable)
@@ -4314,7 +4329,9 @@ def _end_async_for_handler_3_11(
43144329
**kwargs,
43154330
) -> None | INTERPRETER_SIGNALS:
43164331
runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx()
4317-
assert inst.arg is None
4332+
if sys.version_info < (3, 14):
4333+
# 3.14+ has an (unused) int arg pointing to the END_SEND
4334+
assert inst.arg is None
43184335

43194336
val = stack.pop()
43204337
assert isinstance(val, BaseException)
@@ -4855,7 +4872,7 @@ def _list_to_tuple_handler(inst: dis.Instruction, /, stack: InterpreterStack, **
48554872

48564873

48574874
# https://docs.python.org/3.13/library/dis.html#opcode-LOAD_ASSERTION_ERROR
4858-
@register_opcode_handler("LOAD_ASSERTION_ERROR")
4875+
@register_opcode_handler("LOAD_ASSERTION_ERROR", max_ver=(3, 13))
48594876
def _load_assertion_error_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
48604877
stack.append(wrap_const(AssertionError))
48614878

@@ -4889,7 +4906,35 @@ def _load_attr_handler_3_12(
48894906
return load_method_helper(obj, name, stack)
48904907

48914908

4892-
# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_ATTR
4909+
# https://docs.python.org/3.14/library/dis.html#opcode-LOAD_SPECIAL
4910+
@register_opcode_handler("LOAD_SPECIAL", min_ver=(3, 14))
4911+
def _load_special_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
4912+
arg = inst.arg
4913+
mgr = stack.pop_wrapped()
4914+
4915+
type_name = type(unwrap(mgr)).__name__
4916+
special_methods = [
4917+
("__enter__", f"'{type_name}' object does not support the context manager protocol (missed __enter__ method)"),
4918+
("__exit__", f"'{type_name}' object does not support the context manager protocol (missed __exit__ method)"),
4919+
(
4920+
"__aenter__",
4921+
f"'{type_name}' object does not support the asynchronous context manager protocol (missed __aenter__ method)",
4922+
),
4923+
(
4924+
"__aexit__",
4925+
f"'{type_name}' object does not support the asynchronous context manager protocol (missed __aexit__ method)",
4926+
),
4927+
]
4928+
name, error_msg = special_methods[arg]
4929+
res = load_method_helper(mgr, wrap_const(name), stack)
4930+
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
4931+
# clear previous error?
4932+
return do_raise(error_msg)
4933+
4934+
return res
4935+
4936+
4937+
# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_SUPER_ATTR
48934938
@register_opcode_handler("LOAD_SUPER_ATTR", min_ver=(3, 12))
48944939
def _load_super_attr_handler(
48954940
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs
@@ -4944,6 +4989,23 @@ def _load_closure_handler(
49444989
stack.append(val)
49454990

49464991

4992+
# https://docs.python.org/3.14/library/dis.html#opcode-LOAD_COMMON_CONSTANT
4993+
@register_opcode_handler("LOAD_COMMON_CONSTANT", min_ver=(3, 14))
4994+
def _load_common_constant_handler(inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs) -> None:
4995+
assert type(inst.arg) is int
4996+
4997+
common_consts = [
4998+
AssertionError,
4999+
NotImplementedError,
5000+
tuple,
5001+
all,
5002+
any,
5003+
]
5004+
constant = wrap_const(common_consts[inst.arg])
5005+
constant = const_callback(constant)
5006+
stack.append(constant)
5007+
5008+
49475009
# https://docs.python.org/3.10/library/dis.html#opcode-LOAD_CONST
49485010
@register_opcode_handler("LOAD_CONST")
49495011
def _load_const_handler(inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs) -> None:
@@ -4955,6 +5017,16 @@ def _load_const_handler(inst: dis.Instruction, /, stack: InterpreterStack, co: C
49555017
stack.append(constant)
49565018

49575019

5020+
# https://docs.python.org/3.14/library/dis.html#opcode-LOAD_SMALL_INT
5021+
@register_opcode_handler("LOAD_SMALL_INT", min_ver=(3, 14))
5022+
def _load_small_int_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
5023+
assert type(inst.arg) is int
5024+
5025+
constant = wrap_const(inst.arg)
5026+
constant = const_callback(constant)
5027+
stack.append(constant)
5028+
5029+
49585030
# https://docs.python.org/3.10/library/dis.html#opcode-LOAD_DEREF
49595031
@register_opcode_handler("LOAD_DEREF")
49605032
def _load_deref_handler(
@@ -4993,9 +5065,10 @@ def _load_deref_handler(
49935065
# https://docs.python.org/3.10/library/dis.html#opcode-LOAD_FAST
49945066
# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_FAST_CHECK
49955067
# LOAD_FAST for Python <3.12 is LOAD_FAST_CHECK
5068+
@register_opcode_handler("LOAD_FAST_BORROW", min_ver=(3, 14))
49965069
@register_opcode_handler("LOAD_FAST_CHECK", min_ver=(3, 12))
49975070
@register_opcode_handler("LOAD_FAST")
4998-
def _load_fast_check_handler(
5071+
def _load_fast_check_load_fast_borrow_handler(
49995072
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs
50005073
) -> None | INTERPRETER_SIGNALS:
50015074
assert isinstance(inst.arg, int)
@@ -5019,8 +5092,9 @@ def _load_fast_check_handler(
50195092

50205093

50215094
# https://docs.python.org/3.13/library/dis.html#opcode-LOAD_FAST_LOAD_FAST
5095+
@register_opcode_handler("LOAD_FAST_BORROW_LOAD_FAST_BORROW", min_ver=(3, 14))
50225096
@register_opcode_handler("LOAD_FAST_LOAD_FAST", min_ver=(3, 13))
5023-
def _load_fast_load_fast_handler(
5097+
def _load_fast_load_fast_load_fast_borrow_load_fast_borrow_handler(
50245098
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs
50255099
) -> None | INTERPRETER_SIGNALS:
50265100
assert isinstance(inst.arg, int)
@@ -5126,7 +5200,7 @@ def _load_global_handler(
51265200

51275201

51285202
# https://docs.python.org/3.11/library/dis.html#opcode-LOAD_METHOD
5129-
@register_opcode_handler("LOAD_METHOD")
5203+
@register_opcode_handler("LOAD_METHOD", max_ver=(3, 13))
51305204
def _load_method_handler(
51315205
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs
51325206
) -> None | INTERPRETER_SIGNALS:
@@ -5360,6 +5434,13 @@ def _set_function_attribute_handler(
53605434
stack.append(fn)
53615435
return
53625436

5437+
if flag == 0x10:
5438+
annotate = stack.pop()
5439+
assert annotate is None or callable(annotate)
5440+
fn.__annotate__ = annotate
5441+
stack.append(fn)
5442+
return
5443+
53635444
assert False, f"Flag value 0x{flag:x} unexpected in SET_FUNCTION_ATTRIBUTE"
53645445

53655446

@@ -5503,6 +5584,12 @@ def _match_sequence_handler(inst: dis.Instruction, /, stack: InterpreterStack, *
55035584
stack.push(supported_sequence)
55045585

55055586

5587+
# https://docs.python.org/3.14/library/dis.html#opcode-NOT_TAKEN
5588+
@register_opcode_handler("NOT_TAKEN", min_ver=(3, 14))
5589+
def _not_taken_handler(inst: dis.Instruction, /, **kwargs) -> None:
5590+
pass
5591+
5592+
55065593
# https://docs.python.org/3.10/library/dis.html#opcode-NOP
55075594
@register_opcode_handler("NOP")
55085595
def _nop_handler(inst: dis.Instruction, /, **kwargs) -> None:
@@ -5777,6 +5864,12 @@ def _pop_top_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs
57775864
stack.pop_wrapped()
57785865

57795866

5867+
# https://docs.python.org/3.14/library/dis.html#opcode-POP_ITER
5868+
@register_opcode_handler("POP_ITER", min_ver=(3, 14))
5869+
def _pop_iter_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
5870+
stack.pop_wrapped()
5871+
5872+
57805873
# Returns either
57815874
def do_raise(exc: Any = Py_NULL(), cause: Any = Py_NULL()) -> Literal[INTERPRETER_SIGNALS.EXCEPTION_RAISED]:
57825875
# Get the type and exception being raised
@@ -5944,7 +6037,7 @@ def _reraise_handler_3_11(
59446037

59456038

59466039
# https://docs.python.org/3.12/library/dis.html#opcode-RETURN_CONST
5947-
@register_opcode_handler("RETURN_CONST", min_ver=(3, 12))
6040+
@register_opcode_handler("RETURN_CONST", min_ver=(3, 12), max_ver=(3, 13))
59486041
def _return_const_handler(
59496042
inst: dis.Instruction, /, co: CodeType, stack: InterpreterStack, **kwargs
59506043
) -> int | None | INTERPRETER_SIGNALS:
@@ -6576,7 +6669,7 @@ def _with_except_start_handler_3_10(
65766669

65776670

65786671
# https://docs.python.org/3.11/library/dis.html#opcode-WITH_EXCEPT_START
6579-
@register_opcode_handler("WITH_EXCEPT_START", min_ver=(3, 11))
6672+
@register_opcode_handler("WITH_EXCEPT_START", min_ver=(3, 11), max_ver=(3, 13))
65806673
def _with_except_start_handler_3_11(
65816674
inst: dis.Instruction, *, inst_ptr: int, stack: InterpreterStack, try_stack: list[PyTryBlock], **kwargs
65826675
) -> None | INTERPRETER_SIGNALS:
@@ -6590,6 +6683,27 @@ def _with_except_start_handler_3_11(
65906683
return check_and_append(stack, _interpret_call_with_unwrapping(exit_func, exc, val, tb))
65916684

65926685

6686+
# https://docs.python.org/3.14/library/dis.html#opcode-WITH_EXCEPT_START
6687+
@register_opcode_handler("WITH_EXCEPT_START", min_ver=(3, 14))
6688+
def _with_except_start_handler_3_11(
6689+
inst: dis.Instruction, *, inst_ptr: int, stack: InterpreterStack, try_stack: list[PyTryBlock], **kwargs
6690+
) -> None | INTERPRETER_SIGNALS:
6691+
# in 3.11 the exception representation changed to only val
6692+
val = stack[-1]
6693+
exc = type(val)
6694+
tb = val.__traceback__
6695+
6696+
assert isinstance(stack[-3], int)
6697+
exit_meth_or_func = stack[-5]
6698+
exit_meth_self_or_null = stack[-4]
6699+
if isinstance(exit_meth_self_or_null, Py_NULL):
6700+
return check_and_append(stack, _interpret_call_with_unwrapping(exit_meth_or_func, exc, val, tb))
6701+
else:
6702+
return check_and_append(
6703+
stack, _interpret_call_with_unwrapping(exit_meth_or_func, exit_meth_self_or_null, exc, val, tb)
6704+
)
6705+
6706+
65936707
# https://docs.python.org/3.10/library/dis.html#opcode-YIELD_FROM
65946708
@register_opcode_handler("YIELD_FROM", max_ver=(3, 10))
65956709
def _yield_from_handler(
@@ -7157,7 +7271,7 @@ def _setup_frame_and_run_python_function(
71577271
for i, (name, value) in enumerate(zip(code.co_freevars, closure)):
71587272
local = freevar_callback(name, value, fn=wrapped_fn, idx=i)
71597273
localsplus.append(local)
7160-
elif (3, 11) <= sys.version_info < (3, 14):
7274+
elif (3, 11) <= sys.version_info < (3, 15):
71617275
assert len(code.co_varnames) == code.co_nlocals
71627276
for n in code.co_varnames:
71637277
local = locals_dict.get(n, Py_NULL())

thunder/tests/test_interpreter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,10 @@ def kwargsplat(**kwargs):
340340
assert is_jitting_with_raise() == jitting
341341
return foo(**kwargs)
342342

343-
assert any(i.opname == "CALL_FUNCTION_EX" and not i.arg & 1 for i in dis.get_instructions(argsplat))
344-
assert any(i.opname == "CALL_FUNCTION_EX" and i.arg & 1 for i in dis.get_instructions(kwargsplat))
343+
if sys.version_info < (3, 14):
344+
# Python 3.14 has no arg in call function ex.
345+
assert any(i.opname == "CALL_FUNCTION_EX" and not i.arg & 1 for i in dis.get_instructions(argsplat))
346+
assert any(i.opname == "CALL_FUNCTION_EX" and i.arg & 1 for i in dis.get_instructions(kwargsplat))
345347

346348
kwargs = {"a": 1, "b": 2}
347349

@@ -357,6 +359,10 @@ def kwargsplat(**kwargs):
357359
assert_close(res2, jres2)
358360

359361

362+
@pytest.mark.skipif(
363+
sys.version_info >= (3, 14),
364+
reason="Python 3.14+ do not implement BUILD_CONST_KEY_MAP",
365+
)
360366
def test_build_const_key_map(jit):
361367
def fn1(a, b):
362368
return {"a": a, "b": b}
@@ -1074,7 +1080,7 @@ def foo(x):
10741080

10751081
jfoo = jit(foo)
10761082

1077-
with pytest.raises(Exception, match=r"reduce\(\) takes no keyword arguments"):
1083+
with pytest.raises(Exception, match=r"reduce\(\) takes .* arguments"): # Varies for 3.10-3.13, 3.14
10781084
foo((1, 2, 3))
10791085

10801086
with pytest.raises(Exception, match=r"got some positional-only arguments passed as keyword arguments"):

0 commit comments

Comments
 (0)