Skip to content

Commit 44c1b1d

Browse files
authored
Merge pull request #521 from juripetersen/fix-python-join
Fix join key extraction for python join operator
2 parents f9a677a + e27e518 commit 44c1b1d

File tree

4 files changed

+61
-9
lines changed

4 files changed

+61
-9
lines changed

python/src/pywy/core/serializer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ def serialize(self, operator):
5353

5454
json_operator["data"] = {}
5555

56-
if hasattr(operator, "input_type"):
57-
if operator.input_type is not None:
58-
json_operator["data"]["inputType"] = ndim_from_type(operator.input_type).to_json()
59-
if hasattr(operator, "output_type"):
60-
if operator.output_type is not None:
61-
json_operator["data"]["outputType"] = ndim_from_type(operator.output_type).to_json()
56+
if operator.json_name != "join":
57+
if hasattr(operator, "input_type"):
58+
if operator.input_type is not None:
59+
json_operator["data"]["inputType"] = ndim_from_type(operator.input_type).to_json()
60+
if hasattr(operator, "output_type"):
61+
if operator.output_type is not None:
62+
json_operator["data"]["outputType"] = ndim_from_type(operator.output_type).to_json()
6263

6364
if operator.json_name == "filter":
6465
json_operator["data"]["udf"] = base64.b64encode(cloudpickle.dumps(operator.use_predicate)).decode('utf-8')

python/src/pywy/operators/binary.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,13 @@ def __init__(
6363
output_type: GenericTco
6464
):
6565
super().__init__("Join", input_type, output_type)
66-
self.this_key_function = lambda g: this_key_function(next(g))
66+
self.this_key_function = lambda g: this_key_function(ast.literal_eval(next(g)))
6767
self.that = that
68-
self.that_key_function = lambda g: that_key_function(next(g))
68+
self.that_key_function = lambda g: that_key_function(ast.literal_eval(next(g)))
6969
self.json_name = "join"
7070

7171

72+
7273
class DLTrainingOperator(BinaryToUnaryOperator):
7374
model: Model
7475
option: Option

python/src/pywy/tests/join_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
#from typing import Tuple, Callable, Iterable
20+
from pywy.dataquanta import WayangContext
21+
from unittest.mock import Mock
22+
from pywy.platforms.java import JavaPlugin
23+
from pywy.platforms.spark import SparkPlugin
24+
25+
class TestJoin(unittest.TestCase):
26+
def test_to_json(self):
27+
ctx = WayangContext() \
28+
.register({JavaPlugin, SparkPlugin})
29+
30+
left = ctx.textfile("file:///var/www/html/data/left.csv").map(lambda x: tuple(x.split(",")), (int, str), (int, str))
31+
right = ctx.textfile("file:///var/www/html/data/right.csv").map(lambda x: tuple(x.split(",")), (int, str), (int, str))
32+
33+
def join_key(item: (int, str)) -> (int):
34+
print(f"join item {item}")
35+
print(f"key: {item[0]}")
36+
37+
return item[0]
38+
39+
join = left.join(join_key, right, join_key) \
40+
.store_textfile("file:///var/www/html/data/join-out-python.txt")
41+
42+
self.assertEqual(True, True)
43+
44+
if __name__ == "__main__":
45+
unittest.main()

wayang-api/wayang-api-python/src/main/java/org/apache/wayang/api/python/function/WrappedTransformationDescriptor.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ public WrappedTransformationDescriptor(
3939
input.add(item);
4040
final PythonWorkerManager<Input, Output> manager = new PythonWorkerManager<>(serializedUDF, input);
4141
final Iterable<Output> output = manager.execute();
42-
return output.iterator().next();
42+
43+
if (output.iterator().hasNext()) {
44+
return output.iterator().next();
45+
}
46+
47+
return null;
4348
},
4449
inputTypeClass,
4550
outputTypeClass

0 commit comments

Comments
 (0)