Skip to content

Commit e68a79c

Browse files
authored
Validate circular reference for yaml (#33208)
1 parent 286e29c commit e68a79c

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,21 @@ def preprocess_languages(spec):
956956
else:
957957
return spec
958958

959+
def validate_transform_references(spec):
  """Rejects a transform spec that lists itself among its own inputs.

  Looks at the ``'input'`` entry of ``spec['input']`` (either a single string
  or a collection of strings) and compares every reference against the
  transform's ``name`` and ``type``.

  Args:
    spec: Mapping describing one transform. ``spec['input']`` is expected to
      be a mapping when present; presumably an earlier preprocessing phase
      normalizes it to that shape -- confirm against the phase ordering.

  Returns:
    ``spec`` itself, unmodified.

  Raises:
    ValueError: If any input reference equals the transform's name or type.
  """
  name = spec.get('name', '')
  transform_type = spec.get('type')

  # A spec without an 'input' entry (or with an explicit None) has nothing to
  # validate; calling .get() on None would raise AttributeError instead of
  # passing the spec through.
  input_spec = spec.get('input')
  if input_spec is None:
    return spec

  inputs = input_spec.get('input', [])
  if is_empty(inputs):
    return spec

  # A lone string is one reference, not an iterable of characters.
  input_values = [inputs] if isinstance(inputs, str) else inputs
  for input_value in input_values:
    # A reference matching the transform's type is treated the same as one
    # matching its name.
    if input_value in (name, transform_type):
      raise ValueError(
          f"Circular reference detected: Transform {name} "
          f"references itself as input in {identify_object(spec)}")

  return spec
973+
959974
for phase in [
960975
ensure_transforms_have_types,
961976
normalize_mapping,
@@ -966,6 +981,7 @@ def preprocess_languages(spec):
966981
preprocess_chain,
967982
tag_explicit_inputs,
968983
normalize_inputs_outputs,
984+
validate_transform_references,
969985
preprocess_flattened_inputs,
970986
ensure_errors_consumed,
971987
preprocess_windowing,

sdks/python/apache_beam/yaml/yaml_transform_test.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,51 @@ def test_csv_to_json(self):
259259
lines=True).sort_values('rank').reindex()
260260
pd.testing.assert_frame_equal(data, result)
261261

262+
def test_circular_reference_validation(self):
  """A transform that names itself as its single input is rejected."""
  options = beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')
  with beam.Pipeline(options=options) as p:
    # pylint: disable=expression-not-assigned
    # The Create transform below lists itself ("Create") as its own input.
    with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
      p | YamlTransform(
          '''
          type: composite
          transforms:
            - type: Create
              name: Create
              config:
                elements: [0, 1, 3, 4]
              input: Create
            - type: PyMap
              name: PyMap
              config:
                fn: "lambda row: row.element * row.element"
              input: Create
          output: PyMap
          ''',
          providers=TEST_PROVIDERS)
284+
285+
def test_circular_reference_multi_inputs_validation(self):
  """A self-reference inside a list of inputs is rejected as well."""
  options = beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')
  with beam.Pipeline(options=options) as p:
    # pylint: disable=expression-not-assigned
    # PyMap lists itself next to a legitimate upstream transform (Create).
    with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
      p | YamlTransform(
          '''
          type: composite
          transforms:
            - type: Create
              name: Create
              config:
                elements: [0, 1, 3, 4]
            - type: PyMap
              name: PyMap
              config:
                fn: "lambda row: row.element * row.element"
              input: [Create, PyMap]
          output: PyMap
          ''',
          providers=TEST_PROVIDERS)
306+
262307
def test_name_is_not_ambiguous(self):
263308
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
264309
pickle_library='cloudpickle')) as p:
@@ -285,7 +330,7 @@ def test_name_is_ambiguous(self):
285330
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
286331
pickle_library='cloudpickle')) as p:
287332
# pylint: disable=expression-not-assigned
288-
with self.assertRaisesRegex(ValueError, r'Ambiguous.*'):
333+
with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
289334
p | YamlTransform(
290335
'''
291336
type: composite

0 commit comments

Comments
 (0)