Skip to content

Commit db08b7c

Browse files
authored
Type inference tests (#36776)
* Document _infer_result_type * More tests. * Comments.
1 parent 61b8f41 commit db08b7c

File tree

2 files changed

+195
-0
lines changed

2 files changed

+195
-0
lines changed

sdks/python/apache_beam/pipeline.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,100 @@ def _infer_result_type(
866866
transform: ptransform.PTransform,
867867
inputs: Sequence[Union[pvalue.PBegin, pvalue.PCollection]],
868868
result_pcollection: Union[pvalue.PValue, pvalue.DoOutputsTuple]) -> None:
869+
"""Infer and set the output element type for a PCollection.
870+
871+
This function determines the output types of transforms by combining:
872+
1. Concrete input types from previous transforms
873+
2. Type hints declared on the current transform
874+
3. Type variable binding and substitution
875+
876+
TYPE VARIABLE BINDING
877+
---------------------
878+
Type variables (K, V, T, etc.) act as placeholders that get bound to
879+
concrete types through pattern matching. This requires both an input
880+
pattern and an output template:
881+
882+
Input Pattern (from .with_input_types()):
883+
Defines where in the input to find each type variable
884+
Example: Tuple[K, V] means "K is the first element, V is the second"
885+
886+
Output Template (from .with_output_types()):
887+
Defines how to use the bound variables in the output
888+
Example: Tuple[V, K] means "swap the positions"
889+
890+
CONCRETE TYPES VS TYPE VARIABLES
891+
---------------------------------
892+
The system handles these differently:
893+
894+
Concrete Types (e.g., str, int, Tuple[str, int]):
895+
- Used as-is without any binding
896+
- Do not fall back to Any
897+
- Example: .with_output_types(Tuple[str, int]) → Tuple[str, int]
898+
899+
Type Variables (e.g., K, V, T):
900+
- Must be bound through pattern matching
901+
- Require .with_input_types() to provide the pattern
902+
- Fall back to Any if not bound
903+
- Example without pattern: Tuple[K, V] → Tuple[Any, Any]
904+
- Example with pattern: Tuple[K, V] → Tuple[str, int]
905+
906+
BINDING ALGORITHM
907+
-----------------
908+
1. Match: Compare input pattern to concrete input
909+
Pattern: Tuple[K, V]
910+
Concrete: Tuple[str, int]
911+
Result: {K: str, V: int} ← Bindings created
912+
913+
2. Substitute: Apply bindings to output template
914+
Template: Tuple[V, K] ← Note: swapped!
915+
Bindings: {K: str, V: int}
916+
Result: Tuple[int, str] ← Swapped concrete types
917+
918+
Each transform operates in its own type inference scope. Type variables
919+
declared in a parent composite transform do NOT automatically propagate
920+
to child transforms.
921+
922+
Parent scope (composite):
923+
@with_input_types(Tuple[K, V]) ← K, V defined here
924+
class MyComposite(PTransform):
925+
def expand(self, pcoll):
926+
# Child scope - parent's K, V are NOT available
927+
return pcoll | ChildTransform()
928+
929+
Type variables that remain unbound after inference fall back to Any.
930+
931+
EXAMPLES
932+
--------
933+
Example 1: Concrete types (no variables)
934+
Input: Tuple[str, int]
935+
Transform: .with_output_types(Tuple[str, int])
936+
Output: Tuple[str, int] ← Used as-is
937+
938+
Example 2: Type variables with pattern (correct)
939+
Input: Tuple[str, int]
940+
Transform: .with_input_types(Tuple[K, V])
941+
.with_output_types(Tuple[V, K])
942+
Binding: {K: str, V: int}
943+
Output: Tuple[int, str] ← Swapped!
944+
945+
Example 3: Type variables without pattern (falls back to Any)
946+
Input: Tuple[str, int]
947+
Transform: .with_output_types(Tuple[K, V]) ← No input pattern!
948+
Binding: None (can't match)
949+
Output: Tuple[Any, Any] ← Fallback
950+
951+
Example 4: Mixed concrete and variables
952+
Input: Tuple[str, int]
953+
Transform: .with_input_types(Tuple[str, V])
954+
.with_output_types(Tuple[str, V])
955+
Binding: {V: int} ← Only V needs binding
956+
Output: Tuple[str, int] ← str passed through, V bound to int
957+
958+
Args:
959+
transform: The PTransform being applied
960+
inputs: Input PBegin/PCollections (these provide the concrete input types)
961+
result_pcollection: Output PCollection to set type on
962+
"""
869963
# TODO(robertwb): Multi-input inference.
870964
type_options = self._options.view_as(TypeOptions)
871965
if type_options is None or not type_options.pipeline_type_check:
@@ -881,6 +975,7 @@ def _infer_result_type(
881975
else typehints.Union[input_element_types_tuple])
882976
type_hints = transform.get_type_hints()
883977
declared_output_type = type_hints.simple_output_type(transform.label)
978+
884979
if declared_output_type:
885980
input_types = type_hints.input_types
886981
if input_types and input_types[0]:
@@ -893,6 +988,7 @@ def _infer_result_type(
893988
result_element_type = declared_output_type
894989
else:
895990
result_element_type = transform.infer_output_type(input_element_type)
991+
896992
# Any remaining type variables have no bindings higher than this scope.
897993
result_pcollection.element_type = typehints.bind_type_variables(
898994
result_element_type, {'*': typehints.Any})

sdks/python/apache_beam/transforms/ptransform_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,105 @@ def process(self, element, five):
14021402
assert_that(d, equal_to([6, 7, 8]))
14031403
self.p.run()
14041404

1405+
def test_child_with_both_input_and_output_hints_binds_typevars_correctly(
1406+
self):
1407+
"""
1408+
When a child transform has both input and output type hints with type
1409+
variables, those variables bind correctly from the actual input data.
1410+
1411+
Example: Child with .with_input_types(Tuple[K, V])
1412+
.with_output_types(Tuple[K, V]) receiving ('a', 'hello') will bind
1413+
K=str, V=str correctly.
1414+
"""
1415+
K = typehints.TypeVariable('K')
1416+
V = typehints.TypeVariable('V')
1417+
1418+
@typehints.with_input_types(typehints.Tuple[K, V])
1419+
@typehints.with_output_types(typehints.Tuple[K, V])
1420+
class TransformWithoutChildHints(beam.PTransform):
1421+
class MyDoFn(beam.DoFn):
1422+
def process(self, element):
1423+
k, v = element
1424+
yield (k, v.upper())
1425+
1426+
def expand(self, pcoll):
1427+
return (
1428+
pcoll
1429+
| beam.ParDo(self.MyDoFn()).with_input_types(
1430+
tuple[K, V]).with_output_types(tuple[K, V]))
1431+
1432+
with TestPipeline() as p:
1433+
result = (
1434+
p
1435+
| beam.Create([('a', 'hello'), ('b', 'world')])
1436+
| TransformWithoutChildHints())
1437+
1438+
self.assertEqual(result.element_type, typehints.Tuple[str, str])
1439+
1440+
def test_child_without_input_hints_fails_to_bind_typevars(self):
1441+
"""
1442+
When a child transform lacks input type hints, type variables in its output
1443+
hints cannot bind and default to Any, even when the parent composite has
1444+
decorated type hints.
1445+
1446+
This test demonstrates the current limitation: without explicit input hints
1447+
on the child, the type variable K in .with_output_types(Tuple[K, str])
1448+
remains unbound, resulting in Tuple[Any, str] instead of the expected
1449+
Tuple[str, str].
1450+
"""
1451+
K = typehints.TypeVariable('K')
1452+
1453+
@typehints.with_input_types(typehints.Tuple[K, str])
1454+
@typehints.with_output_types(typehints.Tuple[K, str])
1455+
class TransformWithoutChildHints(beam.PTransform):
1456+
class MyDoFn(beam.DoFn):
1457+
def process(self, element):
1458+
k, v = element
1459+
yield (k, v.upper())
1460+
1461+
def expand(self, pcoll):
1462+
return (
1463+
pcoll
1464+
| beam.ParDo(self.MyDoFn()).with_output_types(tuple[K, str]))
1465+
1466+
with TestPipeline() as p:
1467+
result = (
1468+
p
1469+
| beam.Create([('a', 'hello'), ('b', 'world')])
1470+
| TransformWithoutChildHints())
1471+
1472+
self.assertEqual(result.element_type, typehints.Tuple[typehints.Any, str])
1473+
1474+
def test_child_without_output_hints_infers_partial_types_from_dofn(self):
1475+
"""
1476+
When a child transform has input hints but no output hints, type inference
1477+
from the DoFn's process method produces partially inferred types.
1478+
1479+
Type inference is able to infer the first element of the tuple as str, but
1480+
not the type of v.upper(), which falls back to Any.
1481+
"""
1482+
K = typehints.TypeVariable('K')
1483+
V = typehints.TypeVariable('V')
1484+
1485+
@typehints.with_input_types(typehints.Tuple[K, V])
1486+
@typehints.with_output_types(typehints.Tuple[K, V])
1487+
class TransformWithoutChildHints(beam.PTransform):
1488+
class MyDoFn(beam.DoFn):
1489+
def process(self, element):
1490+
k, v = element
1491+
yield (k, v.upper())
1492+
1493+
def expand(self, pcoll):
1494+
return (pcoll | beam.ParDo(self.MyDoFn()).with_input_types(tuple[K, V]))
1495+
1496+
with TestPipeline() as p:
1497+
result = (
1498+
p
1499+
| beam.Create([('a', 'hello'), ('b', 'world')])
1500+
| TransformWithoutChildHints())
1501+
1502+
self.assertEqual(result.element_type, typehints.Tuple[str, typing.Any])
1503+
14051504
def test_do_fn_pipeline_pipeline_type_check_violated(self):
14061505
@with_input_types(str, str)
14071506
@with_output_types(str)

0 commit comments

Comments
 (0)