Skip to content

Commit e0b9d03

Browse files
authored
Merge pull request #34073 Allow declaration of external dependencies for YAML UDFs.
2 parents 5a8964c + 18f237e commit e0b9d03

File tree

12 files changed

+290
-22
lines changed

12 files changed

+290
-22
lines changed

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,14 @@ def get_all_options(
419419
_LOGGER.warning(
420420
'Unknown pipeline options received: %s. Ignore if flags are '
421421
'used for internal purposes.' % (','.join(unknown_args)))
422+
423+
seen = set()
424+
425+
def add_new_arg(arg, **kwargs):
426+
if arg not in seen:
427+
parser.add_argument(arg, **kwargs)
428+
seen.add(arg)
429+
422430
i = 0
423431
while i < len(unknown_args):
424432
# End of argument parsing.
@@ -432,12 +440,12 @@ def get_all_options(
432440
if i + 1 >= len(unknown_args) or unknown_args[i + 1].startswith('-'):
433441
split = unknown_args[i].split('=', 1)
434442
if len(split) == 1:
435-
parser.add_argument(unknown_args[i], action='store_true')
443+
add_new_arg(unknown_args[i], action='store_true')
436444
else:
437-
parser.add_argument(split[0], type=str)
445+
add_new_arg(split[0], type=str)
438446
i += 1
439447
elif unknown_args[i].startswith('--'):
440-
parser.add_argument(unknown_args[i], type=str)
448+
add_new_arg(unknown_args[i], type=str)
441449
i += 2
442450
else:
443451
# skip all binary flags used with '-' and not '--'.

sdks/python/apache_beam/runners/portability/expansion_service_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def main(argv):
5555
with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
5656
known_args.fully_qualified_name_glob):
5757

58-
address = '[::]:{}'.format(known_args.port)
58+
address = 'localhost:{}'.format(known_args.port)
5959
server = grpc.server(thread_pool_executor.shared_unbounded_instance())
6060
if known_args.serve_loopback_worker:
6161
beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(

sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,7 @@ def start_worker(self):
620620
stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub(
621621
GRPCChannelFactory.insecure_channel(
622622
self._external_payload.endpoint.url))
623+
_LOGGER.info('self.control_address: %s' % self.control_address)
623624
control_descriptor = endpoints_pb2.ApiServiceDescriptor(
624625
url=self.control_address)
625626
response = stub.StartWorker(

sdks/python/apache_beam/utils/subprocess_server.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,21 +335,24 @@ def path_to_maven_jar(
335335
])
336336

337337
@classmethod
338-
def path_to_beam_jar(
338+
def parse_gradle_target(cls, gradle_target, artifact_id=None):
339+
gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
340+
if not artifact_id:
341+
artifact_id = 'beam-' + gradle_package.replace(':', '-')
342+
return gradle_package, artifact_id
343+
344+
@classmethod
345+
def path_to_dev_beam_jar(
339346
cls,
340347
gradle_target,
341348
appendix=None,
342349
version=beam_version,
343350
artifact_id=None):
344-
if gradle_target in cls._BEAM_SERVICES.replacements:
345-
return cls._BEAM_SERVICES.replacements[gradle_target]
346-
347-
gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
348-
if not artifact_id:
349-
artifact_id = 'beam-' + gradle_package.replace(':', '-')
351+
gradle_package, artifact_id = cls.parse_gradle_target(
352+
gradle_target, artifact_id)
350353
project_root = os.path.sep.join(
351354
os.path.abspath(__file__).split(os.path.sep)[:-5])
352-
local_path = os.path.join(
355+
return os.path.join(
353356
project_root,
354357
gradle_package.replace(':', os.path.sep),
355358
'build',
@@ -359,6 +362,22 @@ def path_to_beam_jar(
359362
version.replace('.dev', ''),
360363
classifier='SNAPSHOT',
361364
appendix=appendix))
365+
366+
@classmethod
367+
def path_to_beam_jar(
368+
cls,
369+
gradle_target,
370+
appendix=None,
371+
version=beam_version,
372+
artifact_id=None):
373+
if gradle_target in cls._BEAM_SERVICES.replacements:
374+
return cls._BEAM_SERVICES.replacements[gradle_target]
375+
376+
_, artifact_id = cls.parse_gradle_target(gradle_target, artifact_id)
377+
project_root = os.path.sep.join(
378+
os.path.abspath(__file__).split(os.path.sep)[:-5])
379+
local_path = cls.path_to_dev_beam_jar(
380+
gradle_target, appendix, version, artifact_id)
362381
if os.path.exists(local_path):
363382
_LOGGER.info('Using pre-built snapshot at %s', local_path)
364383
return local_path

sdks/python/apache_beam/yaml/examples/testing/examples_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import os
2222
import random
23+
import sys
2324
import unittest
2425
from typing import Any
2526
from typing import Callable
@@ -29,13 +30,15 @@
2930
from typing import Union
3031
from unittest import mock
3132

33+
import pytest
3234
import yaml
3335

3436
import apache_beam as beam
3537
from apache_beam import PCollection
3638
from apache_beam.examples.snippets.util import assert_matches_stdout
3739
from apache_beam.options.pipeline_options import PipelineOptions
3840
from apache_beam.testing.test_pipeline import TestPipeline
41+
from apache_beam.utils import subprocess_server
3942
from apache_beam.yaml import yaml_provider
4043
from apache_beam.yaml import yaml_transform
4144
from apache_beam.yaml.readme_test import TestEnvironment
@@ -263,6 +266,30 @@ def test_yaml_example(self):
263266
actual += list(transform.outputs.values())
264267
check_output(expected)(actual)
265268

269+
if 'deps' in pipeline_spec_file:
270+
test_yaml_example = pytest.mark.no_xdist(test_yaml_example)
271+
test_yaml_example = unittest.skipIf(
272+
sys.platform == 'win32', "Github virtualenv permissions issues.")(
273+
test_yaml_example)
274+
# This test fails, with an import error, for some (but not all) cloud
275+
# tox environments when run as a github action (not reproducible locally).
276+
# Adding debugging makes the failure go away. All indications are that
277+
# this is some testing environmental issue.
278+
test_yaml_example = unittest.skipIf(
279+
'-cloud' in os.environ.get('TOX_ENV_NAME', ''),
280+
'Github actions environment issue.')(
281+
test_yaml_example)
282+
283+
if 'java_deps' in pipeline_spec_file:
284+
test_yaml_example = pytest.mark.xlang_sql_expansion_service(
285+
test_yaml_example)
286+
test_yaml_example = unittest.skipIf(
287+
not os.path.exists(
288+
subprocess_server.JavaJarServer.path_to_dev_beam_jar(
289+
'sdks:java:extensions:sql:expansion-service:shadowJar')),
290+
"Requires expansion service jars.")(
291+
test_yaml_example)
292+
266293
return test_yaml_example
267294

268295

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# coding=utf-8
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one or more
4+
# contributor license agreements. See the NOTICE file distributed with
5+
# this work for additional information regarding copyright ownership.
6+
# The ASF licenses this file to You under the Apache License, Version 2.0
7+
# (the "License"); you may not use this file except in compliance with
8+
# the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
pipeline:
20+
type: chain
21+
transforms:
22+
- type: Create
23+
config:
24+
elements:
25+
- {sdk: MapReduce, year: 2004}
26+
- {sdk: MillWheel, year: 2008}
27+
- {sdk: Flume, year: 2010}
28+
- {sdk: Dataflow, year: 2014}
29+
- {sdk: Apache Beam, year: 2016}
30+
- type: MapToFields
31+
name: ToRoman
32+
config:
33+
language: python
34+
fields:
35+
tool_name: sdk
36+
year:
37+
callable: |
38+
import roman
39+
40+
def convert(row):
41+
return roman.toRoman(row.year)
42+
dependencies:
43+
- 'roman>=4.2'
44+
- type: LogForTesting
45+
46+
# Expected:
47+
# Row(tool_name='MapReduce', year='MMIV')
48+
# Row(tool_name='MillWheel', year='MMVIII')
49+
# Row(tool_name='Flume', year='MMX')
50+
# Row(tool_name='Dataflow', year='MMXIV')
51+
# Row(tool_name='Apache Beam', year='MMXVI')
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# coding=utf-8
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one or more
4+
# contributor license agreements. See the NOTICE file distributed with
5+
# this work for additional information regarding copyright ownership.
6+
# The ASF licenses this file to You under the Apache License, Version 2.0
7+
# (the "License"); you may not use this file except in compliance with
8+
# the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
pipeline:
20+
type: chain
21+
transforms:
22+
- type: Create
23+
config:
24+
elements:
25+
- {sdk: MapReduce, year: 2004}
26+
- {sdk: MillWheel, year: 2008}
27+
- {sdk: Flume, year: 2010}
28+
- {sdk: Dataflow, year: 2014}
29+
- {sdk: Apache Beam, year: 2016}
30+
- type: MapToFields
31+
name: ToRoman
32+
config:
33+
language: java
34+
fields:
35+
tool_name: sdk
36+
year:
37+
callable: |
38+
import org.apache.beam.sdk.values.Row;
39+
import java.util.function.Function;
40+
import com.github.chaosfirebolt.converter.RomanInteger;
41+
42+
public class MyFunction implements Function<Row, String> {
43+
public String apply(Row row) {
44+
return RomanInteger.parse(
45+
String.valueOf(row.getInt64("year"))).toString();
46+
}
47+
}
48+
dependencies:
49+
- 'com.github.chaosfirebolt.converter:roman-numeral-converter:2.1.0'
50+
- type: LogForTesting
51+
52+
# Expected:
53+
# Row(tool_name='MapReduce', year='MMIV')
54+
# Row(tool_name='MillWheel', year='MMVIII')
55+
# Row(tool_name='Flume', year='MMX')
56+
# Row(tool_name='Dataflow', year='MMXIV')
57+
# Row(tool_name='Apache Beam', year='MMXVI')

sdks/python/apache_beam/yaml/yaml_enrichment_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from apache_beam import Row
2525
from apache_beam.testing.util import assert_that
2626
from apache_beam.testing.util import equal_to
27+
from apache_beam.yaml import yaml_provider
2728
from apache_beam.yaml.yaml_transform import YamlTransform
2829

2930

@@ -59,6 +60,8 @@ def test_enrichment_with_bigquery(self):
5960
with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform',
6061
FakeEnrichmentTransform(enrichment_handler=handler,
6162
handler_config=config)):
63+
# Force a reload to respect our mock.
64+
yaml_provider.standard_providers.cache_clear()
6265
input_pcoll = p | 'CreateInput' >> beam.Create(input_data)
6366
result = input_pcoll | YamlTransform(
6467
f'''

sdks/python/apache_beam/yaml/yaml_mapping.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,8 @@ def _PyJsMapToFields(
678678
fields: Mapping[str, Union[str, Mapping[str, str]]],
679679
append: Optional[bool] = False,
680680
drop: Optional[Iterable[str]] = None,
681-
language: Optional[str] = None):
681+
language: Optional[str] = None,
682+
dependencies: Optional[Iterable[str]] = None):
682683
"""Creates records with new fields defined in terms of the input fields.
683684
684685
See more complete documentation on
@@ -694,6 +695,8 @@ def _PyJsMapToFields(
694695
original record that should not be kept
695696
language: The language used to define (and execute) the
696697
expressions and/or callables in `fields`. Defaults to generic.
698+
dependencies: An optional list of extra dependencies that are needed for
699+
these UDFs. The interpretation of these strings is language-dependent.
697700
error_handling: Whether and where to output records that throw errors when
698701
the above expressions are evaluated.
699702
""" # pylint: disable=line-too-long

0 commit comments

Comments
 (0)