Skip to content

Commit ca6aeb8

Browse files
committed
add itertools.combinations
1 parent b19072b commit ca6aeb8

File tree

1 file changed

+73
-2
lines changed

1 file changed

+73
-2
lines changed

graalpython/lib-graalpython/itertools.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __next__(self):
6969
class starmap():
7070
pass
7171

72+
7273
class islice(object):
7374
def __init__(self, iterable, *args):
7475
self._iterable = enumerate(iter(iterable))
@@ -419,12 +420,13 @@ def __next__(self):
419420
if not self.func(n):
420421
return n
421422

423+
422424
class takewhile(object):
423425
"""Make an iterator that returns elements from the iterable as
424426
long as the predicate is true.
425427
426428
Equivalent to :
427-
429+
428430
def takewhile(predicate, iterable):
429431
for x in iterable:
430432
if predicate(x):
@@ -446,6 +448,7 @@ def __next__(self):
446448
raise StopIteration()
447449
return value
448450

451+
449452
class groupby(object):
450453
"""Make an iterator that returns consecutive keys and groups from the
451454
iterable. The key is a function computing a key value for each
@@ -464,7 +467,7 @@ class groupby(object):
464467
for k, g in groupby(data, keyfunc):
465468
groups.append(list(g)) # Store group iterator as a list
466469
uniquekeys.append(k)
467-
"""
470+
"""
468471
def __init__(self, iterable, key=None):
469472
if key is None:
470473
key = lambda x: x
@@ -487,3 +490,71 @@ def _grouper(self, tgtkey):
487490
yield self._currvalue
488491
self._currvalue = next(self._iter) # Exit on StopIteration
489492
self._currkey = self._keyfunc(self._currvalue)
493+
494+
495+
class combinations():
496+
"""
497+
combinations(iterable, r) --> combinations object
498+
499+
Return successive r-length combinations of elements in the iterable.
500+
501+
combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)
502+
"""
503+
504+
def __init__(self, pool, indices, r):
505+
self.pool = pool
506+
self.indices = range(indices)
507+
if r < 0:
508+
raise ValueError("r must be non-negative")
509+
self.r = r
510+
self.last_result = None
511+
self.stopped = r > len(pool)
512+
513+
def get_maximum(self, i):
514+
return i + len(self.pool) - self.r
515+
516+
def max_index(self, j):
517+
return self.indices[j - 1] + 1
518+
519+
def __iter__(self):
520+
return self
521+
522+
def __next__(self):
523+
if self.stopped:
524+
raise StopIteration
525+
if self.last_result is None:
526+
# On the first pass, initialize result tuple using the indices
527+
result = [None] * self.r
528+
for i in range(self.r):
529+
index = self.indices[i]
530+
result[i] = self.pool[index]
531+
else:
532+
# Copy the previous result
533+
result = self.last_result[:]
534+
# Scan indices right-to-left until finding one that is not at its
535+
# maximum
536+
i = self.r - 1
537+
while i >= 0 and self.indices[i] == self.get_maximum(i):
538+
i -= 1
539+
540+
# If i is negative, then the indices are all at their maximum value
541+
# and we're done
542+
if i < 0:
543+
self.stopped = True
544+
raise StopIteration
545+
546+
# Increment the current index which we know is not at its maximum.
547+
# Then move back to the right setting each index to its lowest
548+
# possible value
549+
self.indices[i] += 1
550+
for j in range(i + 1, self.r):
551+
self.indices[j] = self.max_index(j)
552+
553+
# Update the result for the new indices starting with i, the
554+
# leftmost index that changed
555+
for i in range(i, self.r):
556+
index = self.indices[i]
557+
elem = self.pool[index]
558+
result[i] = elem
559+
self.last_result = result
560+
return tuple(result)

0 commit comments

Comments
 (0)