Skip to content

Commit 2fe6db1

Browse files
authored
cp.sum support generator (cvxpy#3058)
* document current sum of list behaviour * support generator input * use GeneratorType to make generator support cleaner --------- Co-authored-by: Muhammad Yasirroni <yasirroni@users.noreply.github.com>
1 parent 74867aa commit 2fe6db1

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

cvxpy/atoms/affine/sum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616
import builtins
1717
from functools import wraps
18+
from types import GeneratorType
1819
from typing import Optional, Tuple
1920

2021
import numpy as np
@@ -145,7 +146,7 @@ def sum(expr, axis: Optional[int] = None, keepdims: bool = False):
145146
"""
146147
Wrapper for Sum class.
147148
"""
148-
if isinstance(expr, list):
149+
if isinstance(expr, (list, GeneratorType)):
149150
return builtins.sum(expr)
150151
else:
151152
return Sum(expr, axis, keepdims)

cvxpy/tests/test_atoms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,14 @@ def test_sum(self) -> None:
492492
self.assertEqual(cp.sum(Variable(2)).shape, tuple())
493493
self.assertEqual(cp.sum(Variable(2)).curvature, s.AFFINE)
494494
self.assertEqual(cp.sum(Variable((2, 1)), keepdims=True).shape, (1, 1))
495+
496+
# Iterables and Generators
497+
self.assertEqual(cp.sum([Variable(1) for _ in range(3)]).shape, (1,))
498+
self.assertEqual(cp.sum([Variable(2) for _ in range(3)]).shape, (2,))
499+
self.assertEqual(cp.sum(Variable(1) for _ in range(3)).shape, (1,))
500+
self.assertEqual(cp.sum(Variable(2) for _ in range(3)).shape, (2,))
501+
self.assertEqual(cp.sum(range(3)).shape, tuple())
502+
495503
# Mixed curvature.
496504
mat = np.array([[1, -1]])
497505
self.assertEqual(cp.sum(mat @ cp.square(Variable(2))).curvature, s.UNKNOWN)

0 commit comments

Comments
 (0)