diff --git a/mypyc/doc/list_operations.rst b/mypyc/doc/list_operations.rst index 378568865501..bb4681266cab 100644 --- a/mypyc/doc/list_operations.rst +++ b/mypyc/doc/list_operations.rst @@ -33,7 +33,7 @@ Operators * ``lst[n]`` (get item by integer index) * ``lst[n:m]``, ``lst[n:]``, ``lst[:m]``, ``lst[:]`` (slicing) * ``lst1 + lst2``, ``lst += iter`` -* ``lst * n``, ``n * lst`` +* ``lst * n``, ``n * lst``, ``lst *= n`` * ``obj in lst`` Statements diff --git a/mypyc/doc/tuple_operations.rst b/mypyc/doc/tuple_operations.rst index ed603fa9982d..4c9da9b894af 100644 --- a/mypyc/doc/tuple_operations.rst +++ b/mypyc/doc/tuple_operations.rst @@ -22,6 +22,7 @@ Operators * ``tup[n]`` (integer index) * ``tup[n:m]``, ``tup[n:]``, ``tup[:m]`` (slicing) * ``tup1 + tup2`` +* ``tup * n``, ``n * tup`` Statements ---------- diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 7b192e747595..aeb559a50a7a 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -664,6 +664,7 @@ int CPyList_Remove(PyObject *list, PyObject *obj); CPyTagged CPyList_Index(PyObject *list, PyObject *obj); PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size); PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq); +PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size); PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); PyObject *CPyList_Copy(PyObject *list); int CPySequence_Check(PyObject *obj); diff --git a/mypyc/lib-rt/list_ops.c b/mypyc/lib-rt/list_ops.c index 8388e1eea73a..b47fcec8ffe9 100644 --- a/mypyc/lib-rt/list_ops.c +++ b/mypyc/lib-rt/list_ops.c @@ -331,6 +331,14 @@ PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq) { return CPySequence_Multiply(seq, t_size); } +PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size) { + Py_ssize_t size = CPyTagged_AsSsize_t(t_size); + if (size == -1 && PyErr_Occurred()) { + return NULL; + } + return PySequence_InPlaceRepeat(seq, size); +} + PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) { if (likely(PyList_CheckExact(obj) && CPyTagged_CheckShort(start) && CPyTagged_CheckShort(end))) { diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index 5cc8b3c0d1c6..6063fdfd680e 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -307,6 +307,15 @@ error_kind=ERR_MAGIC, ) +# list *= int +binary_op( + name="*=", + arg_types=[list_rprimitive, int_rprimitive], + return_type=list_rprimitive, + c_function_name="CPySequence_InPlaceMultiply", + error_kind=ERR_MAGIC, +) + # list[begin:end] list_slice_op = custom_op( arg_types=[list_rprimitive, int_rprimitive, int_rprimitive], diff --git a/mypyc/primitives/tuple_ops.py b/mypyc/primitives/tuple_ops.py index f28d4ca5ec7a..a9bbaa80fb5c 100644 --- a/mypyc/primitives/tuple_ops.py +++ b/mypyc/primitives/tuple_ops.py @@ -83,6 +83,24 @@ error_kind=ERR_MAGIC, ) +# tuple * int +binary_op( + name="*", + arg_types=[tuple_rprimitive, int_rprimitive], + return_type=tuple_rprimitive, + c_function_name="CPySequence_Multiply", + error_kind=ERR_MAGIC, +) + +# int * tuple +binary_op( + name="*", + arg_types=[int_rprimitive, tuple_rprimitive], + return_type=tuple_rprimitive, + c_function_name="CPySequence_RMultiply", + error_kind=ERR_MAGIC, +) + # tuple[begin:end] tuple_slice_op = custom_op( arg_types=[tuple_rprimitive, int_rprimitive, int_rprimitive], diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index e82c79459709..4e9b917ad03b 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -212,6 +212,8 @@ def __contains__(self, item: object) -> int: ... def __add__(self, value: Tuple[T_co, ...], /) -> Tuple[T_co, ...]: ... @overload def __add__(self, value: Tuple[_T, ...], /) -> Tuple[T_co | _T, ...]: ... + def __mul__(self, value: int, /) -> Tuple[T_co, ...]: ... + def __rmul__(self, value: int, /) -> Tuple[T_co, ...]: ... class function: pass @@ -225,6 +227,7 @@ def __setitem__(self, i: int, o: _T) -> None: pass def __delitem__(self, i: int) -> None: pass def __mul__(self, i: int) -> List[_T]: pass def __rmul__(self, i: int) -> List[_T]: pass + def __imul__(self, i: int) -> List[_T]: ... def __iter__(self) -> Iterator[_T]: pass def __len__(self) -> int: pass def __contains__(self, item: object) -> int: ... diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test index b7ba1a783bb7..c2e2df133fc5 100644 --- a/mypyc/test-data/irbuild-lists.test +++ b/mypyc/test-data/irbuild-lists.test @@ -194,6 +194,18 @@ L0: b = r4 return 1 +[case testListIMultiply] +from typing import List +def f(a: List[int]) -> None: + a *= 2 +[out] +def f(a): + a, r0 :: list +L0: + r0 = CPySequence_InPlaceMultiply(a, 4) + a = r0 + return 1 + [case testListLen] from typing import List def f(a: List[int]) -> int: diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index e7280bb3b552..582391ff6f98 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -418,3 +418,38 @@ L0: r3 = unbox(tuple[int, int, int, int], r2) c = r3 return 1 + +[case testTupleMultiply] +from typing import Tuple +def f(a: Tuple[int]) -> None: + b = a * 2 + c = 3 * (2,) +def g(a: Tuple[int, ...]) -> None: + b = a * 2 +[out] +def f(a): + a :: tuple[int] + r0 :: object + r1 :: tuple + r2, b :: tuple[int, int] + r3 :: tuple[int] + r4 :: object + r5 :: tuple + r6, c :: tuple[int, int, int] +L0: + r0 = box(tuple[int], a) + r1 = CPySequence_Multiply(r0, 4) + r2 = unbox(tuple[int, int], r1) + b = r2 + r3 = (4) + r4 = box(tuple[int], r3) + r5 = CPySequence_RMultiply(6, r4) + r6 = unbox(tuple[int, int, int], r5) + c = r6 + return 1 +def g(a): + a, r0, b :: tuple +L0: + r0 = CPySequence_Multiply(a, 4) + b = r0 + return 1 diff --git a/mypyc/test-data/run-lists.test b/mypyc/test-data/run-lists.test index 84168f7254f5..b6d9a811d910 100644 --- a/mypyc/test-data/run-lists.test +++ b/mypyc/test-data/run-lists.test @@ -313,6 +313,13 @@ def test_add() -> None: assert in_place_add({3: "", 4: ""}) == res assert in_place_add(range(3, 5)) == res +def test_multiply() -> None: + l1 = [1] + assert l1 * 3 == [1, 1, 1] + assert 3 * l1 == [1, 1, 1] + l1 *= 3 + assert l1 == [1, 1, 1] + [case testOperatorInExpression] def tuple_in_int0(i: int) -> bool: diff --git a/mypyc/test-data/run-tuples.test b/mypyc/test-data/run-tuples.test index afd3a956b871..fcf1def9b8fc 100644 --- a/mypyc/test-data/run-tuples.test +++ b/mypyc/test-data/run-tuples.test @@ -261,3 +261,12 @@ def test_add() -> None: assert (1, 2) + (3, 4) == res with assertRaises(TypeError, 'can only concatenate tuple (not "list") to tuple'): assert (1, 2) + cast(Any, [3, 4]) == res + +def multiply(a: Tuple[Any, ...], b: int) -> Tuple[Any, ...]: + return a * b + +def test_multiply() -> None: + res = (1, 1, 1) + assert (1,) * 3 == res + assert 3 * (1,) == res + assert multiply((1,), 3) == res