Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3577,6 +3577,10 @@ def process(self, element, partitionfn, n, *args, **kwargs):
raise ValueError(
'PartitionFn specified out-of-bounds partition index: '
'%d not in [0, %d)' % (partition, n))
if isinstance(partition, bool) or not isinstance(partition, int):
raise ValueError(
f"PartitionFn yielded a '{type(partition).__name__}' "
"when it should only yields integers")
# Each input is directed into the output that corresponds to the
# selected partition.
yield pvalue.TaggedOutput(str(partition), element)
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ def test_dofn_with_implicit_return_none_return_without_value(self):


class PartitionTest(unittest.TestCase):
def test_partition_with_bools(self):
with pytest.raises(
ValueError,
match="PartitionFn yielded a 'bool' when it should only yields integers"
):
with beam.testing.test_pipeline.TestPipeline() as p:
_ = (p | beam.Create([True]) | beam.Partition(lambda x, _: x, 2))

def test_partition_boundedness(self):
def partition_fn(val, num_partitions):
return val % num_partitions
Expand Down
Loading