Skip to content

Commit cdb24de

Browse files
committed
Allow dependency declarations for Java.
1 parent eff7beb commit cdb24de

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# coding=utf-8
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one or more
4+
# contributor license agreements. See the NOTICE file distributed with
5+
# this work for additional information regarding copyright ownership.
6+
# The ASF licenses this file to You under the Apache License, Version 2.0
7+
# (the "License"); you may not use this file except in compliance with
8+
# the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
# Example pipeline: rewrites each row's year using a Java callable that
# relies on an extra library declared under `dependencies`.
pipeline:
  type: chain
  transforms:
    # Source rows: one (sdk, year) pair per element.
    - type: Create
      config:
        elements:
          - sdk: MapReduce
            year: 2004
          - sdk: MillWheel
            year: 2008
          - sdk: Flume
            year: 2010
          - sdk: Dataflow
            year: 2014
          - sdk: Apache Beam
            year: 2016
    - type: MapToFields
      name: ToRoman
      config:
        language: java
        fields:
          # Output field `tool_name` is taken from input field `sdk`.
          tool_name: sdk
          # Output field `year` is computed by the inline Java class below.
          year:
            callable: |
              import org.apache.beam.sdk.values.Row;
              import java.util.function.Function;
              import com.github.chaosfirebolt.converter.RomanInteger;

              public class MyFunction implements Function<Row, String> {
                public String apply(Row row) {
                  return RomanInteger.parse(
                      String.valueOf(row.getInt64("year"))).toString();
                }
              }
        # Extra dependency needed by the callable's RomanInteger import
        # (looks like a group:artifact:version coordinate).
        dependencies:
          - 'com.github.chaosfirebolt.converter:roman-numeral-converter:2.1.0'
    - type: LogForTesting
51+
52+
# Expected:
53+
# Row(tool_name='MapReduce', year='MMIV')
54+
# Row(tool_name='MillWheel', year='MMVIII')
55+
# Row(tool_name='Flume', year='MMX')
56+
# Row(tool_name='Dataflow', year='MMXIV')
57+
# Row(tool_name='Apache Beam', year='MMXVI')

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,13 @@ def cache_artifacts(self):
340340

341341

342342
class ExternalJavaProvider(ExternalProvider):
343-
def __init__(self, urns, jar_provider, classpath=None):
  """Creates a provider whose expansion service is started from a Java jar.

  Args:
    urns: forwarded unchanged to the ExternalProvider constructor.
    jar_provider: zero-argument callable returning the jar handed to
      JavaJarExpansionService. It is invoked lazily, not during construction.
    classpath: optional value forwarded as the ``classpath`` keyword of
      JavaJarExpansionService; defaults to None.
  """
  def make_service():
    # Deferred so the jar is only resolved when the service is requested.
    return external.JavaJarExpansionService(
        jar_provider(), classpath=classpath)

  super().__init__(urns, make_service)
  # Kept so cache_artifacts and derived providers can reuse them.
  self._jar_provider = jar_provider
  self._classpath = classpath
347350

348351
def available(self):
349352
# pylint: disable=subprocess-run-check
@@ -353,6 +356,15 @@ def available(self):
353356
def cache_artifacts(self):
  """Resolves the jar via the jar provider and returns it in a list."""
  jar = self._jar_provider()
  return [jar]
355358

359+
def _with_extra_dependencies(self, dependencies: Iterable[str]):
  """Returns a new provider with extra jars appended to the classpath.

  Args:
    dependencies: dependency strings; each is passed to
      JavaJarExpansionService._expand_jars and the results are concatenated
      in iteration order.

  Returns:
    A new ExternalJavaProvider with the same urns and jar provider, whose
    classpath is this provider's classpath (if any) followed by the expanded
    jars. This provider itself is not modified.
  """
  # Function-scope import keeps this method self-contained.
  import itertools

  # chain.from_iterable flattens in a single pass, whereas sum(..., [])
  # re-copies the accumulated list for every dependency.
  jars = list(
      itertools.chain.from_iterable(
          external.JavaJarExpansionService._expand_jars(dep)
          for dep in dependencies))
  return ExternalJavaProvider(
      self._urns,
      jar_provider=self._jar_provider,
      classpath=list(self._classpath or []) + jars)
367+
356368

357369
@ExternalProvider.register_provider_type('python')
358370
def python(urns, provider_base_path, packages=()):

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def expand(pcolls):
348348
if pcoll in providers_by_input
349349
]
350350
provider = self.best_provider(spec, input_providers)
351-
extra_dependencies = extract_extra_dependencies(spec)
351+
extra_dependencies, spec = extract_extra_dependencies(spec)
352352
if extra_dependencies:
353353
provider = provider.with_extra_dependencies(frozenset(extra_dependencies))
354354

@@ -687,9 +687,13 @@ def extract_name(spec):
687687

688688
def extract_extra_dependencies(spec):
  """Separates declared extra dependencies from a transform spec.

  Args:
    spec: transform spec mapping; dependencies are read from
      ``spec['config']['dependencies']`` when present.

  Returns:
    A ``(dependencies, spec)`` pair. When no dependencies are declared (the
    key is missing or its value is falsy) the result is ``([], spec)`` with
    the original spec object. Otherwise the second element is a shallow copy
    of ``spec`` whose ``config`` omits the ``dependencies`` key; the input
    spec is not mutated.

  Raises:
    TypeError: if a non-empty ``dependencies`` value is not a list.
  """
  deps = spec.get('config', {}).get('dependencies', [])
  if not deps:
    return [], spec
  if not isinstance(deps, list):
    raise TypeError(f'Dependencies must be a list of strings, got {deps}')
  # Strip the key so it is not forwarded as ordinary transform configuration.
  return deps, dict(
      spec,
      config={k: v for k, v in spec['config'].items() if k != 'dependencies'})
693697

694698

695699
def push_windowing_to_roots(spec):

0 commit comments

Comments
 (0)