14 changes: 11 additions & 3 deletions sdks/python/apache_beam/options/pipeline_options.py
@@ -419,6 +419,14 @@ def get_all_options(
         _LOGGER.warning(
             'Unknown pipeline options received: %s. Ignore if flags are '
             'used for internal purposes.' % (','.join(unknown_args)))
+
+      seen = set()
+
+      def add_new_arg(arg, **kwargs):
+        if arg not in seen:
+          parser.add_argument(arg, **kwargs)
+          seen.add(arg)
+
       i = 0
       while i < len(unknown_args):
         # End of argument parsing.
@@ -432,12 +440,12 @@ def get_all_options(
         if i + 1 >= len(unknown_args) or unknown_args[i + 1].startswith('-'):
           split = unknown_args[i].split('=', 1)
           if len(split) == 1:
-            parser.add_argument(unknown_args[i], action='store_true')
+            add_new_arg(unknown_args[i], action='store_true')
           else:
-            parser.add_argument(split[0], type=str)
+            add_new_arg(split[0], type=str)
           i += 1
         elif unknown_args[i].startswith('--'):
-          parser.add_argument(unknown_args[i], type=str)
+          add_new_arg(unknown_args[i], type=str)
           i += 2
         else:
           # skip all binary flags used with '-' and not '--'.
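A quick aside on why the new add_new_arg guard is useful: argparse raises an error if the same option string is registered twice, which can happen when an unknown flag is repeated on the command line. The following standalone sketch (not Beam code; the flag name is invented) illustrates the behavior the guard avoids.

import argparse

parser = argparse.ArgumentParser()
seen = set()

def add_new_arg(arg, **kwargs):
    # Register each unknown flag at most once; a second parser.add_argument()
    # call with the same option string would raise argparse.ArgumentError.
    if arg not in seen:
        parser.add_argument(arg, **kwargs)
        seen.add(arg)

add_new_arg('--my_custom_flag', type=str)
add_new_arg('--my_custom_flag', type=str)  # Ignored instead of raising.
print(parser.parse_args(['--my_custom_flag', 'value']))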
@@ -55,7 +55,7 @@ def main(argv):
   with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
       known_args.fully_qualified_name_glob):

-    address = '[::]:{}'.format(known_args.port)
+    address = 'localhost:{}'.format(known_args.port)
     server = grpc.server(thread_pool_executor.shared_unbounded_instance())
     if known_args.serve_loopback_worker:
       beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
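For context on the address change above: '[::]:port' asks gRPC to listen on all IPv6 interfaces, while 'localhost:port' keeps the service reachable only from the local machine. A minimal sketch of the distinction, assuming grpcio is installed (the executor and port choice here are illustrative, not the values used by the service):

from concurrent import futures

import grpc

server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
# Binding to 'localhost:0' lets the OS pick a free port and restricts the
# listener to loopback; '[::]:0' would accept connections on every interface.
port = server.add_insecure_port('localhost:0')
server.start()
print('listening on localhost:%d' % port)
server.stop(grace=None)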
@@ -620,6 +620,7 @@ def start_worker(self):
     stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub(
         GRPCChannelFactory.insecure_channel(
             self._external_payload.endpoint.url))
+    _LOGGER.info('self.control_address: %s' % self.control_address)
     control_descriptor = endpoints_pb2.ApiServiceDescriptor(
         url=self.control_address)
     response = stub.StartWorker(
35 changes: 27 additions & 8 deletions sdks/python/apache_beam/utils/subprocess_server.py
@@ -335,21 +335,24 @@ def path_to_maven_jar(
         ])

   @classmethod
-  def path_to_beam_jar(
+  def parse_gradle_target(cls, gradle_target, artifact_id=None):
+    gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
+    if not artifact_id:
+      artifact_id = 'beam-' + gradle_package.replace(':', '-')
+    return gradle_package, artifact_id
+
+  @classmethod
+  def path_to_dev_beam_jar(
       cls,
       gradle_target,
       appendix=None,
       version=beam_version,
       artifact_id=None):
-    if gradle_target in cls._BEAM_SERVICES.replacements:
-      return cls._BEAM_SERVICES.replacements[gradle_target]
-
-    gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
-    if not artifact_id:
-      artifact_id = 'beam-' + gradle_package.replace(':', '-')
+    gradle_package, artifact_id = cls.parse_gradle_target(
+        gradle_target, artifact_id)
     project_root = os.path.sep.join(
         os.path.abspath(__file__).split(os.path.sep)[:-5])
-    local_path = os.path.join(
+    return os.path.join(
         project_root,
         gradle_package.replace(':', os.path.sep),
         'build',
@@ -359,6 +362,22 @@ def path_to_beam_jar(
             version.replace('.dev', ''),
             classifier='SNAPSHOT',
             appendix=appendix))
+
+  @classmethod
+  def path_to_beam_jar(
+      cls,
+      gradle_target,
+      appendix=None,
+      version=beam_version,
+      artifact_id=None):
+    if gradle_target in cls._BEAM_SERVICES.replacements:
+      return cls._BEAM_SERVICES.replacements[gradle_target]
+
+    _, artifact_id = cls.parse_gradle_target(gradle_target, artifact_id)
+    project_root = os.path.sep.join(
+        os.path.abspath(__file__).split(os.path.sep)[:-5])
+    local_path = cls.path_to_dev_beam_jar(
+        gradle_target, appendix, version, artifact_id)
     if os.path.exists(local_path):
       _LOGGER.info('Using pre-built snapshot at %s', local_path)
       return local_path
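The split above means callers can ask where a locally built snapshot jar would live without triggering any of the download logic in path_to_beam_jar. A small sketch of how the new helper might be used (it mirrors the check added to examples_test.py below, assumes a source checkout of Beam, and uses the same Gradle target as that test):

import os

from apache_beam.utils import subprocess_server

# Compute the expected location of a locally built expansion-service jar.
local_jar = subprocess_server.JavaJarServer.path_to_dev_beam_jar(
    'sdks:java:extensions:sql:expansion-service:shadowJar')

if os.path.exists(local_jar):
    print('Found locally built snapshot jar at', local_jar)
else:
    print('No local snapshot; a Gradle build would need to produce', local_jar)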
27 changes: 27 additions & 0 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
@@ -20,6 +20,7 @@
 import logging
 import os
 import random
+import sys
 import unittest
 from typing import Any
 from typing import Callable
@@ -29,13 +30,15 @@
 from typing import Union
 from unittest import mock

+import pytest
 import yaml

 import apache_beam as beam
 from apache_beam import PCollection
 from apache_beam.examples.snippets.util import assert_matches_stdout
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.utils import subprocess_server
 from apache_beam.yaml import yaml_provider
 from apache_beam.yaml import yaml_transform
 from apache_beam.yaml.readme_test import TestEnvironment
@@ -263,6 +266,30 @@ def test_yaml_example(self):
       actual += list(transform.outputs.values())
     check_output(expected)(actual)

+  if 'deps' in pipeline_spec_file:
+    test_yaml_example = pytest.mark.no_xdist(test_yaml_example)
+    test_yaml_example = unittest.skipIf(
+        sys.platform == 'win32', "Github virtualenv permissions issues.")(
+            test_yaml_example)
+    # This test fails, with an import error, for some (but not all) cloud
+    # tox environments when run as a github action (not reproducible locally).
+    # Adding debugging makes the failure go away. All indications are that
+    # this is some testing environmental issue.
+    test_yaml_example = unittest.skipIf(
+        '-cloud' in os.environ.get('TOX_ENV_NAME', ''),
+        'Github actions environment issue.')(
+            test_yaml_example)
+
+  if 'java_deps' in pipeline_spec_file:
+    test_yaml_example = pytest.mark.xlang_sql_expansion_service(
+        test_yaml_example)
+    test_yaml_example = unittest.skipIf(
+        not os.path.exists(
+            subprocess_server.JavaJarServer.path_to_dev_beam_jar(
+                'sdks:java:extensions:sql:expansion-service:shadowJar')),
+        "Requires expansion service jars.")(
+            test_yaml_example)
+
   return test_yaml_example
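The pattern above works because unittest.skipIf(...) and pytest marks are ordinary callables, so dynamically generated test methods can be wrapped after they are defined. A self-contained sketch of that pattern (the spec names and skip condition here are invented for illustration):

import sys
import unittest

def make_test(spec_name):
    def test_method(self):
        self.assertTrue(spec_name.endswith('.yaml'))

    if 'deps' in spec_name:
        # Conditionally wrap the generated test, exactly as a decorator would.
        test_method = unittest.skipIf(
            sys.platform == 'win32', 'skipped on Windows')(test_method)
    return test_method

class GeneratedTests(unittest.TestCase):
    test_simple = make_test('simple.yaml')
    test_with_deps = make_test('with_deps.yaml')

if __name__ == '__main__':
    unittest.main()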
@@ -0,0 +1,51 @@
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

pipeline:
  type: chain
  transforms:
    - type: Create
      config:
        elements:
          - {sdk: MapReduce, year: 2004}
          - {sdk: MillWheel, year: 2008}
          - {sdk: Flume, year: 2010}
          - {sdk: Dataflow, year: 2014}
          - {sdk: Apache Beam, year: 2016}
    - type: MapToFields
      name: ToRoman
      config:
        language: python
        fields:
          tool_name: sdk
          year:
            callable: |
              import roman

              def convert(row):
                return roman.toRoman(row.year)
        dependencies:
          - 'roman>=4.2'
    - type: LogForTesting

# Expected:
#  Row(tool_name='MapReduce', year='MMIV')
#  Row(tool_name='MillWheel', year='MMVIII')
#  Row(tool_name='Flume', year='MMX')
#  Row(tool_name='Dataflow', year='MMXIV')
#  Row(tool_name='Apache Beam', year='MMXVI')
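The callable above leans on the third-party roman package, which is what the dependencies entry installs for the pipeline. A standalone check of the conversion it performs, assuming `pip install roman` has been run locally:

import roman

for year in (2004, 2008, 2010, 2014, 2016):
    # roman.toRoman(2004) -> 'MMIV', matching the expected output above.
    print(year, roman.toRoman(year))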
@@ -0,0 +1,57 @@
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

pipeline:
  type: chain
  transforms:
    - type: Create
      config:
        elements:
          - {sdk: MapReduce, year: 2004}
          - {sdk: MillWheel, year: 2008}
          - {sdk: Flume, year: 2010}
          - {sdk: Dataflow, year: 2014}
          - {sdk: Apache Beam, year: 2016}
    - type: MapToFields
      name: ToRoman
      config:
        language: java
        fields:
          tool_name: sdk
          year:
            callable: |
              import org.apache.beam.sdk.values.Row;
              import java.util.function.Function;
              import com.github.chaosfirebolt.converter.RomanInteger;

              public class MyFunction implements Function<Row, String> {
                public String apply(Row row) {
                  return RomanInteger.parse(
                      String.valueOf(row.getInt64("year"))).toString();
                }
              }
        dependencies:
          - 'com.github.chaosfirebolt.converter:roman-numeral-converter:2.1.0'
    - type: LogForTesting

# Expected:
#  Row(tool_name='MapReduce', year='MMIV')
#  Row(tool_name='MillWheel', year='MMVIII')
#  Row(tool_name='Flume', year='MMX')
#  Row(tool_name='Dataflow', year='MMXIV')
#  Row(tool_name='Apache Beam', year='MMXVI')
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_enrichment_test.py
@@ -24,6 +24,7 @@
 from apache_beam import Row
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.yaml import yaml_provider
 from apache_beam.yaml.yaml_transform import YamlTransform


@@ -59,6 +60,8 @@ def test_enrichment_with_bigquery(self):
     with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform',
                     FakeEnrichmentTransform(enrichment_handler=handler,
                                             handler_config=config)):
+      # Force a reload to respect our mock.
+      yaml_provider.standard_providers.cache_clear()
       input_pcoll = p | 'CreateInput' >> beam.Create(input_data)
       result = input_pcoll | YamlTransform(
           f'''
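The cache_clear() call matters because yaml_provider.standard_providers appears to be memoized (it exposes cache_clear(), as functools-cached functions do), so a value computed before mock.patch takes effect would otherwise be reused. A standalone sketch of that interaction, using invented names:

import functools
from unittest import mock

def real_provider():
    return 'real'

@functools.lru_cache()
def standard_providers():
    return real_provider()

standard_providers()  # Populates the cache with the un-mocked value.

with mock.patch(f'{__name__}.real_provider', return_value='fake'):
    print(standard_providers())  # Still 'real': the stale cached value.
    standard_providers.cache_clear()
    print(standard_providers())  # 'fake': the mock is now respected.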
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -678,7 +678,8 @@ def _PyJsMapToFields(
     fields: Mapping[str, Union[str, Mapping[str, str]]],
     append: Optional[bool] = False,
     drop: Optional[Iterable[str]] = None,
-    language: Optional[str] = None):
+    language: Optional[str] = None,
+    dependencies: Optional[Iterable[str]] = None):
   """Creates records with new fields defined in terms of the input fields.

   See more complete documentation on
@@ -694,6 +695,8 @@ def _PyJsMapToFields(
       original record that should not be kept
     language: The language used to define (and execute) the
       expressions and/or callables in `fields`. Defaults to generic.
+    dependencies: An optional list of extra dependencies that are needed for
+      these UDFs. The interpretation of these strings is language-dependent.
     error_handling: Whether and where to output records that throw errors when
       the above expressions are evaluated.
   """ # pylint: disable=line-too-long