diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index fd6b9c5a0503..6de15b90790d 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -419,6 +419,14 @@ def get_all_options( _LOGGER.warning( 'Unknown pipeline options received: %s. Ignore if flags are ' 'used for internal purposes.' % (','.join(unknown_args))) + + seen = set() + + def add_new_arg(arg, **kwargs): + if arg not in seen: + parser.add_argument(arg, **kwargs) + seen.add(arg) + i = 0 while i < len(unknown_args): # End of argument parsing. @@ -432,12 +440,12 @@ def get_all_options( if i + 1 >= len(unknown_args) or unknown_args[i + 1].startswith('-'): split = unknown_args[i].split('=', 1) if len(split) == 1: - parser.add_argument(unknown_args[i], action='store_true') + add_new_arg(unknown_args[i], action='store_true') else: - parser.add_argument(split[0], type=str) + add_new_arg(split[0], type=str) i += 1 elif unknown_args[i].startswith('--'): - parser.add_argument(unknown_args[i], type=str) + add_new_arg(unknown_args[i], type=str) i += 2 else: # skip all binary flags used with '-' and not '--'. diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_main.py b/sdks/python/apache_beam/runners/portability/expansion_service_main.py index 307f6bd54182..6b89cee6082e 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service_main.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service_main.py @@ -55,7 +55,7 @@ def main(argv): with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter( known_args.fully_qualified_name_glob): - address = '[::]:{}'.format(known_args.port) + address = 'localhost:{}'.format(known_args.port) server = grpc.server(thread_pool_executor.shared_unbounded_instance()) if known_args.serve_loopback_worker: beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index d798e96d3aa3..e5c9e9c7ac89 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -620,6 +620,7 @@ def start_worker(self): stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub( GRPCChannelFactory.insecure_channel( self._external_payload.endpoint.url)) + _LOGGER.info('self.control_address: %s' % self.control_address) control_descriptor = endpoints_pb2.ApiServiceDescriptor( url=self.control_address) response = stub.StartWorker( diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index 043b00dc934b..efb27715cd82 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -335,21 +335,24 @@ def path_to_maven_jar( ]) @classmethod - def path_to_beam_jar( + def parse_gradle_target(cls, gradle_target, artifact_id=None): + gradle_package = gradle_target.strip(':').rsplit(':', 1)[0] + if not artifact_id: + artifact_id = 'beam-' + gradle_package.replace(':', '-') + return gradle_package, artifact_id + + @classmethod + def path_to_dev_beam_jar( cls, gradle_target, appendix=None, version=beam_version, artifact_id=None): - if gradle_target in cls._BEAM_SERVICES.replacements: - return cls._BEAM_SERVICES.replacements[gradle_target] - - gradle_package = gradle_target.strip(':').rsplit(':', 1)[0] - if not artifact_id: - artifact_id = 'beam-' + gradle_package.replace(':', '-') + gradle_package, artifact_id = cls.parse_gradle_target( + gradle_target, artifact_id) project_root = os.path.sep.join( os.path.abspath(__file__).split(os.path.sep)[:-5]) - local_path = os.path.join( + return os.path.join( project_root, gradle_package.replace(':', os.path.sep), 'build', @@ -359,6 +362,22 @@ def path_to_beam_jar( version.replace('.dev', ''), classifier='SNAPSHOT', appendix=appendix)) + + @classmethod + def path_to_beam_jar( + cls, + gradle_target, + appendix=None, + version=beam_version, + artifact_id=None): + if gradle_target in cls._BEAM_SERVICES.replacements: + return cls._BEAM_SERVICES.replacements[gradle_target] + + _, artifact_id = cls.parse_gradle_target(gradle_target, artifact_id) + project_root = os.path.sep.join( + os.path.abspath(__file__).split(os.path.sep)[:-5]) + local_path = cls.path_to_dev_beam_jar( + gradle_target, appendix, version, artifact_id) if os.path.exists(local_path): _LOGGER.info('Using pre-built snapshot at %s', local_path) return local_path diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 109e98410852..ee35e3430766 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -20,6 +20,7 @@ import logging import os import random +import sys import unittest from typing import Any from typing import Callable @@ -29,6 +30,7 @@ from typing import Union from unittest import mock +import pytest import yaml import apache_beam as beam @@ -36,6 +38,7 @@ from apache_beam.examples.snippets.util import assert_matches_stdout from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.utils import subprocess_server from apache_beam.yaml import yaml_provider from apache_beam.yaml import yaml_transform from apache_beam.yaml.readme_test import TestEnvironment @@ -263,6 +266,30 @@ def test_yaml_example(self): actual += list(transform.outputs.values()) check_output(expected)(actual) + if 'deps' in pipeline_spec_file: + test_yaml_example = pytest.mark.no_xdist(test_yaml_example) + test_yaml_example = unittest.skipIf( + sys.platform == 'win32', "Github virtualenv permissions issues.")( + test_yaml_example) + # This test fails, with an import error, for some (but not all) cloud + # tox environments when run as a github action (not reproducible locally). + # Adding debugging makes the failure go away. All indications are that + # this is some testing environmental issue. + test_yaml_example = unittest.skipIf( + '-cloud' in os.environ.get('TOX_ENV_NAME', ''), + 'Github actions environment issue.')( + test_yaml_example) + + if 'java_deps' in pipeline_spec_file: + test_yaml_example = pytest.mark.xlang_sql_expansion_service( + test_yaml_example) + test_yaml_example = unittest.skipIf( + not os.path.exists( + subprocess_server.JavaJarServer.path_to_dev_beam_jar( + 'sdks:java:extensions:sql:expansion-service:shadowJar')), + "Requires expansion service jars.")( + test_yaml_example) + return test_yaml_example diff --git a/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_deps.yaml b/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_deps.yaml new file mode 100644 index 000000000000..a45f6ceb98c0 --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_deps.yaml @@ -0,0 +1,51 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - {sdk: MapReduce, year: 2004} + - {sdk: MillWheel, year: 2008} + - {sdk: Flume, year: 2010} + - {sdk: Dataflow, year: 2014} + - {sdk: Apache Beam, year: 2016} + - type: MapToFields + name: ToRoman + config: + language: python + fields: + tool_name: sdk + year: + callable: | + import roman + + def convert(row): + return roman.toRoman(row.year) + dependencies: + - 'roman>=4.2' + - type: LogForTesting + +# Expected: +# Row(tool_name='MapReduce', year='MMIV') +# Row(tool_name='MillWheel', year='MMVIII') +# Row(tool_name='Flume', year='MMX') +# Row(tool_name='Dataflow', year='MMXIV') +# Row(tool_name='Apache Beam', year='MMXVI') diff --git a/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_java_deps.yaml b/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_java_deps.yaml new file mode 100644 index 000000000000..e32ac1e7b67c --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/elementwise/map_to_fields_with_java_deps.yaml @@ -0,0 +1,57 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - {sdk: MapReduce, year: 2004} + - {sdk: MillWheel, year: 2008} + - {sdk: Flume, year: 2010} + - {sdk: Dataflow, year: 2014} + - {sdk: Apache Beam, year: 2016} + - type: MapToFields + name: ToRoman + config: + language: java + fields: + tool_name: sdk + year: + callable: | + import org.apache.beam.sdk.values.Row; + import java.util.function.Function; + import com.github.chaosfirebolt.converter.RomanInteger; + + public class MyFunction implements Function { + public String apply(Row row) { + return RomanInteger.parse( + String.valueOf(row.getInt64("year"))).toString(); + } + } + dependencies: + - 'com.github.chaosfirebolt.converter:roman-numeral-converter:2.1.0' + - type: LogForTesting + +# Expected: +# Row(tool_name='MapReduce', year='MMIV') +# Row(tool_name='MillWheel', year='MMVIII') +# Row(tool_name='Flume', year='MMX') +# Row(tool_name='Dataflow', year='MMXIV') +# Row(tool_name='Apache Beam', year='MMXVI') diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index e26d6140af23..9a5d848486ca 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -24,6 +24,7 @@ from apache_beam import Row from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_transform import YamlTransform @@ -59,6 +60,8 @@ def test_enrichment_with_bigquery(self): with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', FakeEnrichmentTransform(enrichment_handler=handler, handler_config=config)): + # Force a reload to respect our mock. + yaml_provider.standard_providers.cache_clear() input_pcoll = p | 'CreateInput' >> beam.Create(input_data) result = input_pcoll | YamlTransform( f''' diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 7ffee5c2039b..4f7133838794 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -678,7 +678,8 @@ def _PyJsMapToFields( fields: Mapping[str, Union[str, Mapping[str, str]]], append: Optional[bool] = False, drop: Optional[Iterable[str]] = None, - language: Optional[str] = None): + language: Optional[str] = None, + dependencies: Optional[Iterable[str]] = None): """Creates records with new fields defined in terms of the input fields. See more complete documentation on @@ -694,6 +695,8 @@ def _PyJsMapToFields( original record that should not be kept language: The language used to define (and execute) the expressions and/or callables in `fields`. Defaults to generic. + dependencies: An optional list of extra dependencies that are needed for + these UDFs. The interpretation of these strings is language-dependent. error_handling: Whether and where to output records that throw errors when the above expressions are evaluated. """ # pylint: disable=line-too-long diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 97a84c068a4a..eca6c5f46dd2 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -19,6 +19,7 @@ for where to find and how to invoke services that vend implementations of various PTransforms.""" +import abc import collections import functools import hashlib @@ -30,6 +31,7 @@ import shutil import subprocess import sys +import tempfile import urllib.parse import warnings from collections.abc import Callable @@ -39,6 +41,7 @@ from typing import Optional from typing import Union +import clonevirtualenv import docstring_parser import yaml @@ -78,15 +81,18 @@ def __bool__(self): return False -class Provider: +class Provider(abc.ABC): """Maps transform types names and args to concrete PTransform instances.""" + @abc.abstractmethod def available(self) -> Union[bool, NotAvailableWithReason]: """Returns whether this provider is available to use in this environment.""" raise NotImplementedError(type(self)) + @abc.abstractmethod def cache_artifacts(self) -> Optional[Iterable[str]]: raise NotImplementedError(type(self)) + @abc.abstractmethod def provided_transforms(self) -> Iterable[str]: """Returns a list of transform type names this provider can handle.""" raise NotImplementedError(type(self)) @@ -107,6 +113,7 @@ def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool: """ return not typ.startswith('Read') + @abc.abstractmethod def create_transform( self, typ: str, @@ -145,6 +152,18 @@ def _affinity(self, other: "Provider"): else: return 0 + @functools.cache # pylint: disable=method-cache-max-size-none + def with_extra_dependencies(self, dependencies: Iterable[str]): + result = self._with_extra_dependencies(dependencies) + if not hasattr(result, 'to_json'): + result.to_json = lambda: {'type': type(result).__name__} + return result + + def _with_extra_dependencies(self, dependencies: Iterable[str]): + raise ValueError( + 'This provider of type %s does not support additional dependencies.' % + type(self).__name__) + def as_provider(name, provider_or_constructor): if isinstance(provider_or_constructor, Provider): @@ -339,10 +358,13 @@ def cache_artifacts(self): class ExternalJavaProvider(ExternalProvider): - def __init__(self, urns, jar_provider): + def __init__(self, urns, jar_provider, classpath=None): super().__init__( - urns, lambda: external.JavaJarExpansionService(jar_provider())) + urns, + lambda: external.JavaJarExpansionService( + jar_provider(), classpath=classpath)) self._jar_provider = jar_provider + self._classpath = classpath def available(self): # pylint: disable=subprocess-run-check @@ -364,6 +386,15 @@ def try_decode(bs): def cache_artifacts(self): return [self._jar_provider()] + def _with_extra_dependencies(self, dependencies: Iterable[str]): + jars = sum(( + external.JavaJarExpansionService._expand_jars(dep) + for dep in dependencies), []) + return ExternalJavaProvider( + self._urns, + jar_provider=self._jar_provider, + classpath=(list(self._classpath or []) + list(jars))) + @ExternalProvider.register_provider_type('python') def python(urns, provider_base_path, packages=()): @@ -392,6 +423,8 @@ def is_path_or_urn(package): if is_path_or_urn(package) else package for package in packages ])) + self._packages = packages + def available(self): return True # If we're running this script, we have Python installed. @@ -418,6 +451,10 @@ def _affinity(self, other: "Provider"): else: return super()._affinity(other) + def _with_extra_dependencies(self, dependencies: Iterable[str]): + return ExternalPythonProvider( + self._urns, None, set(self._packages).union(set(dependencies))) + @ExternalProvider.register_provider_type('yaml') class YamlProvider(Provider): @@ -606,6 +643,18 @@ def requires_inputs(self, typ, args): else: return super().requires_inputs(typ, args) + def _with_extra_dependencies(self, dependencies): + external_provider = ExternalPythonProvider( # disable yapf + { + typ: 'apache_beam.yaml.yaml_provider.standard_inline_providers.' + + typ.replace('-', '_') + for typ in self._transform_factories.keys() + }, + '__inline__', + dependencies) + external_provider.to_json = self.to_json + return external_provider + class MetaInlineProvider(InlineProvider): def create_transform(self, type, args, yaml_create_transform): @@ -1017,6 +1066,11 @@ def create_transform( yaml_create_transform: Any) -> beam.PTransform: return self._transforms[typ](self._underlying_provider, **config) + def _with_extra_dependencies(self, dependencies: Iterable[str]): + return TranslatingProvider( + self._transforms, + self._underlying_provider._with_extra_dependencies(dependencies)) + def create_java_builtin_provider(): """Exposes built-in transforms from Java as well as Python to maximize @@ -1068,7 +1122,14 @@ class PypiExpansionService: """Expands transforms by fully qualified name in a virtual environment with the given dependencies. """ - VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs") + if 'TOX_WORK_DIR' in os.environ: + VENV_CACHE = tempfile.mkdtemp( + prefix='test-venv-cache-', dir=os.environ['TOX_WORK_DIR']) + elif 'RUNNER_WORKDIR' in os.environ: + VENV_CACHE = tempfile.mkdtemp( + prefix='test-venv-cache-', dir=os.environ['RUNNER_WORKDIR']) + else: + VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs") def __init__( self, packages: Iterable[str], base_python: str = sys.executable): @@ -1130,10 +1191,7 @@ def _create_venv_from_clone( if not os.path.exists(venv): try: clonable_venv = cls._create_venv_to_clone(base_python) - clonable_python = os.path.join(clonable_venv, 'bin', 'python') - subprocess.run( - [clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv], - check=True) + clonevirtualenv.clone_virtualenv(clonable_venv, venv) venv_pip = os.path.join(venv, 'bin', 'pip') subprocess.run([venv_pip, 'install'] + packages, check=True) with open(venv + '-requirements.txt', 'w') as fout: @@ -1296,6 +1354,14 @@ def underlying_provider(self): def cache_artifacts(self): self._underlying_provider.cache_artifacts() + def _with_extra_dependencies(self, dependencies: Iterable[str]): + return RenamingProvider( + self._transforms, + None, + self._mappings, + self._underlying_provider._with_extra_dependencies(dependencies), + self._defaults) + def _as_list(func): @functools.wraps(func) @@ -1376,6 +1442,7 @@ def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]: return result +@functools.cache def standard_providers(): from apache_beam.yaml.yaml_combine import create_combine_providers from apache_beam.yaml.yaml_mapping import create_mapping_providers @@ -1402,3 +1469,19 @@ def _file_digest(fileobj, digest): hasher.update(data) data = fileobj.read(1 << 20) return hasher + + +class _InlineProviderNamespace: + """Gives fully qualified names to inline providers from standard_providers(). + + This is needed to upgrade InlineProvider to ExternalPythonProvider. + """ + def __getattr__(self, name): + typ = name.replace('_', '-') + for provider in standard_providers()[typ]: + if isinstance(provider, InlineProvider): + return provider._transform_factories[typ] + raise ValueError(f"No inline provider found for {name}") + + +standard_inline_providers = _InlineProviderNamespace() diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 66ce87c9bfe1..744cbe6a8925 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -368,6 +368,9 @@ def expand(pcolls): if pcoll in providers_by_input ] provider = self.best_provider(spec, input_providers) + extra_dependencies, spec = extract_extra_dependencies(spec) + if extra_dependencies: + provider = provider.with_extra_dependencies(frozenset(extra_dependencies)) config = SafeLineLoader.strip_metadata(spec.get('config', {})) if not isinstance(config, dict): @@ -708,6 +711,17 @@ def extract_name(spec): return '' +def extract_extra_dependencies(spec): + deps = spec.get('config', {}).get('dependencies', []) + if not deps: + return [], spec + if not isinstance(deps, list): + raise TypeError(f'Dependencies must be a list of strings, got {deps}') + return deps, dict( + spec, + config={k: v for k, v in spec['config'].items() if k != 'dependencies'}) + + def push_windowing_to_roots(spec): scope = LightweightScope(spec['transforms']) consumed_outputs_by_transform = collections.defaultdict(set) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 9a45d287d49d..49d9ac368811 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -401,7 +401,8 @@ def get_portability_package_data(): 'docstring-parser>=0.15,<1.0', 'docutils>=0.18.1', 'pandas<2.3.0', - 'openai' + 'openai', + 'virtualenv-clone>=0.5,<1.0', ], 'test': [ 'docstring-parser>=0.15,<1.0', @@ -424,6 +425,7 @@ def get_portability_package_data(): 'testcontainers[mysql]>=3.0.3,<4.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', + 'virtualenv-clone>=0.5,<1.0', ], 'gcp': [ 'cachetools>=3.1.0,<6',