Skip to content

Commit 71b1e83

Browse files
authored
fix provider testing issue (#37183)
1 parent 358e007 commit 71b1e83

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def json_config_schema(self, type):
490490
return dict(
491491
type='object',
492492
additionalProperties=False,
493-
**self._transforms[type]['config_schema'])
493+
**self._transforms[type].get('config_schema', {}))
494494

495495
def description(self, type):
496496
return self._transforms[type].get('description')

sdks/python/apache_beam/yaml/yaml_testing.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,15 @@ def __str__(self):
7373

7474
def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
7575
if isinstance(pipeline_spec, str):
76-
pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
76+
pipeline_spec_dict = yaml.load(
77+
pipeline_spec, Loader=yaml_utils.SafeLineLoader)
78+
else:
79+
pipeline_spec_dict = pipeline_spec
7780

78-
pipeline_spec = _preprocess_for_testing(pipeline_spec)
81+
processed_pipeline_spec = _preprocess_for_testing(pipeline_spec_dict)
7982

8083
transform_spec, recording_ids = inject_test_tranforms(
81-
pipeline_spec,
84+
processed_pipeline_spec,
8285
test_spec,
8386
fix_failures)
8487

@@ -96,12 +99,18 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
9699
options = beam.options.pipeline_options.PipelineOptions(
97100
pickle_library='cloudpickle',
98101
**yaml_transform.SafeLineLoader.strip_metadata(
99-
pipeline_spec.get('options', {})))
102+
pipeline_spec_dict.get('options', {})))
103+
104+
providers = yaml_provider.merge_providers(
105+
yaml_provider.parse_providers(
106+
'', pipeline_spec_dict.get('providers', [])),
107+
{
108+
'AssertEqualAndRecord': yaml_provider.as_provider_list(
109+
'AssertEqualAndRecord', AssertEqualAndRecord)
110+
})
100111

101112
with beam.Pipeline(options=options) as p:
102-
_ = p | yaml_transform.YamlTransform(
103-
transform_spec,
104-
providers={'AssertEqualAndRecord': AssertEqualAndRecord})
113+
_ = p | yaml_transform.YamlTransform(transform_spec, providers=providers)
105114

106115
if fix_failures:
107116
fixes = {}

sdks/python/apache_beam/yaml/yaml_testing_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,40 @@ def test_join_transform_serialization(self):
322322
}]
323323
})
324324

325+
def test_toplevel_providers(self):
326+
yaml_testing.run_test(
327+
'''
328+
pipeline:
329+
type: chain
330+
transforms:
331+
- type: Create
332+
config:
333+
elements: [1, 2, 3]
334+
- type: MyDoubler
335+
providers:
336+
- type: yaml
337+
transforms:
338+
MyDoubler:
339+
body:
340+
type: MapToFields
341+
config:
342+
language: python
343+
fields:
344+
doubled: element * 2
345+
''',
346+
{
347+
'expected_outputs': [{
348+
'name': 'MyDoubler',
349+
'elements': [{
350+
'doubled': 2
351+
}, {
352+
'doubled': 4
353+
}, {
354+
'doubled': 6
355+
}]
356+
}]
357+
})
358+
325359

326360
if __name__ == '__main__':
327361
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)