Skip to content

Commit b7c91e6

Browse files
committed
[SYSTEMDS-3223] Python functions with list arguments
This commit fixes a bug where sourced functions did not build their scripts correctly when a list was passed as an input to a function defined in the sourced script.
1 parent 788637a commit b7c91e6

File tree

8 files changed

+131
-25
lines changed

8 files changed

+131
-25
lines changed

src/main/python/systemds/context/systemds_context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def __init__(self, port: int = -1):
6464
if process.poll() is None:
6565
self.__start_gateway(actual_port)
6666
else:
67-
self.exception_and_close(
68-
"Java process stopped before gateway could connect")
67+
self.exception_and_close("Java process stopped before gateway could connect")
6968

7069
def get_stdout(self, lines: int = -1):
7170
"""Getter for the stdout of the java subprocess
@@ -89,7 +88,7 @@ def get_stderr(self, lines: int = -1):
8988
else:
9089
return [self.__stderr.get() for x in range(lines)]
9190

92-
def exception_and_close(self, exception_str: str, trace_back_limit: int = None):
91+
def exception_and_close(self, exception, trace_back_limit: int = None):
9392
"""
9493
Method for printing exception, printing stdout and error, while also closing the context correctly.
9594
@@ -104,7 +103,7 @@ def exception_and_close(self, exception_str: str, trace_back_limit: int = None):
104103
if stdErr:
105104
message += "standard error :\n" + "\n".join(stdErr)
106105
message += "\n\n"
107-
message += exception_str
106+
message += str(exception)
108107
sys.tracebacklimit = trace_back_limit
109108
self.close()
110109
raise RuntimeError(message)

src/main/python/systemds/operator/nodes/list.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@ def pass_python_data_to_prepared_script(self, sds, var_name: str, prepared_scrip
7676

7777
def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
7878
named_input_vars: Dict[str, str]) -> str:
79-
inputs_comma_sep = create_params_string(
80-
unnamed_input_vars, named_input_vars)
81-
return f'{var_name}={self.operation}({inputs_comma_sep});'
79+
code_line = super().code_line(var_name, unnamed_input_vars, named_input_vars)
80+
return code_line
8281

83-
def compute(self, verbose: bool = False, lineage: bool = False) -> Union[np.array]:
82+
def compute(self, verbose: bool = False, lineage: bool = False) -> np.array:
8483
return super().compute(verbose, lineage)
8584

8685
def __str__(self):

src/main/python/systemds/operator/nodes/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
5757
else:
5858
return super().code_line(var_name, unnamed_input_vars, named_input_vars)
5959

60-
def compute(self, verbose: bool = False, lineage: bool = False) -> Union[np.array]:
60+
def compute(self, verbose: bool = False, lineage: bool = False):
6161
return super().compute(verbose, lineage)
6262

6363
def _parse_output_result_variables(self, result_variables):

src/main/python/systemds/script_building/script.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
#
2020
# -------------------------------------------------------------
2121

22-
from typing import Any, Collection, KeysView, Tuple, Union, Optional, Dict, TYPE_CHECKING, List
22+
from typing import (TYPE_CHECKING, Any, Collection, Dict, KeysView, List,
23+
Optional, Tuple, Union)
2324

25+
from py4j.protocol import Py4JNetworkError
2426
from py4j.java_collections import JavaArray
25-
from py4j.java_gateway import JavaObject, JavaGateway
26-
27+
from py4j.java_gateway import JavaGateway, JavaObject
2728
from systemds.script_building.dag import DAGNode, OutputType
2829
from systemds.utils.consts import VALID_INPUT_TYPES
2930

@@ -79,9 +80,14 @@ def execute(self) -> JavaObject:
7980
self.__prepare_script()
8081
ret = self.prepared_script.executeScript()
8182
return ret
83+
except Py4JNetworkError:
84+
exception_str = "Py4JNetworkError: no connection to JVM, most likely due to previous crash or closed JVM from calls to close()"
85+
trace_back_limit = 0
8286
except Exception as e:
83-
self.sds_context.exception_and_close(e)
84-
return None
87+
exception_str = str(e)
88+
trace_back_limit = None
89+
self.sds_context.exception_and_close(exception_str, trace_back_limit)
90+
8591

8692
def execute_with_lineage(self) -> Tuple[JavaObject, str]:
8793
"""If not already created, create a preparedScript from our DMLCode, pass python local data to our prepared
@@ -104,9 +110,13 @@ def execute_with_lineage(self) -> Tuple[JavaObject, str]:
104110
traces.append(self.prepared_script.getLineageTrace(output))
105111
return ret, traces
106112

113+
except Py4JNetworkError:
114+
exception_str = "Py4JNetworkError: no connection to JVM, most likely due to previous crash or closed JVM from calls to close()"
115+
trace_back_limit = 0
107116
except Exception as e:
108-
self.sds_context.exception_and_close(e)
109-
return None, None
117+
exception_str = str(e)
118+
trace_back_limit = None
119+
self.sds_context.exception_and_close(exception_str, trace_back_limit)
110120

111121
def __prepare_script(self):
112122
gateway = self.sds_context.java_gateway
@@ -190,15 +200,13 @@ def _dfs_dag_nodes(self, dag_node: VALID_INPUT_TYPES) -> str:
190200
# for each node do the dfs operation and save the variable names in `input_var_names`
191201
# get variable names of unnamed parameters
192202

193-
unnamed_input_vars = [self._dfs_dag_nodes(
194-
input_node) for input_node in dag_node.unnamed_input_nodes]
203+
unnamed_input_vars = []
204+
for un_node in dag_node.unnamed_input_nodes:
205+
unnamed_input_vars.append(self._dfs_dag_nodes(un_node))
195206

196207
named_input_vars = {}
197208
for name, input_node in dag_node.named_input_nodes.items():
198209
named_input_vars[name] = self._dfs_dag_nodes(input_node)
199-
if isinstance(input_node, DAGNode) and input_node._output_type == OutputType.LIST:
200-
dag_node.dml_name = named_input_vars[name] + name
201-
return dag_node.dml_name
202210

203211
# check if the node gets a name after multireturns
204212
# If it has, great, return that name
@@ -212,6 +220,7 @@ def _dfs_dag_nodes(self, dag_node: VALID_INPUT_TYPES) -> str:
212220

213221
code_line = dag_node.code_line(
214222
dag_node.dml_name, unnamed_input_vars, named_input_vars)
223+
215224
self.add_code(code_line)
216225
return dag_node.dml_name
217226

src/main/python/tests/matrix/test_print.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import unittest
2323

2424
import numpy as np
25+
from time import sleep
2526
from systemds.context import SystemDSContext
2627

2728

@@ -32,6 +33,9 @@ class TestPrint(unittest.TestCase):
3233
@classmethod
3334
def setUpClass(cls):
3435
cls.sds = SystemDSContext()
36+
sleep(1.0)
37+
cls.sds.get_stdout()
38+
cls.sds.get_stdout()
3539

3640
@classmethod
3741
def tearDownClass(cls):

src/main/python/tests/script/test_dml_script.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# -------------------------------------------------------------
2121

2222
import unittest
23-
import time
23+
from time import sleep
2424

2525
from systemds.context import SystemDSContext
2626
from systemds.script_building import DMLScript
@@ -35,6 +35,9 @@ class Test_DMLScript(unittest.TestCase):
3535
@classmethod
3636
def setUpClass(cls):
3737
cls.sds = SystemDSContext()
38+
sleep(1.0)
39+
cls.sds.get_stdout()
40+
cls.sds.get_stdout()
3841

3942
@classmethod
4043
def tearDownClass(cls):
@@ -44,7 +47,7 @@ def test_simple_print_1(self):
4447
script = DMLScript(self.sds)
4548
script.add_code('print("Hello")')
4649
script.execute()
47-
time.sleep(0.5)
50+
sleep(0.5)
4851
stdout = self.sds.get_stdout(100)
4952
self.assertListEqual(["Hello"], stdout)
5053

@@ -54,7 +57,7 @@ def test_simple_print_2(self):
5457
script.add_code('print("World")')
5558
script.add_code('print("!")')
5659
script.execute()
57-
time.sleep(0.5)
60+
sleep(0.5)
5861
stdout = self.sds.get_stdout(100)
5962
self.assertListEqual(['Hello', 'World', '!'], stdout)
6063

@@ -65,7 +68,7 @@ def test_multiple_executions_1(self):
6568
scr_a.add_code('y = x + 1')
6669
scr_a.add_code('print(y)')
6770
scr_a.execute()
68-
time.sleep(0.5)
71+
sleep(0.5)
6972
stdout = self.sds.get_stdout(100)
7073
self.assertEqual("4", stdout[0])
7174
self.assertEqual("5", stdout[1])
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
func = function(list[unknown] a) return (matrix[double] b){
23+
b = as.matrix(a[1])
24+
}
25+
26+
func2 = function(list[unknown] a) return (matrix[double] b, matrix[double] c){
27+
b = as.matrix(a[1])
28+
c = as.matrix(a[2])
29+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
import unittest
23+
24+
import numpy as np
25+
from systemds.context import SystemDSContext
26+
from systemds.operator.algorithm.builtin.scale import scale
27+
28+
29+
class TestSource_01(unittest.TestCase):
30+
31+
sds: SystemDSContext = None
32+
source_path: str = "./tests/source/source_with_list_input.dml"
33+
34+
@classmethod
35+
def setUpClass(cls):
36+
cls.sds = SystemDSContext()
37+
38+
@classmethod
39+
def tearDownClass(cls):
40+
cls.sds.close()
41+
42+
def test_single_return(self):
43+
arr = self.sds.array(self.sds.full((10, 10), 4))
44+
c = self.sds.source(self.source_path, "test").func(arr)
45+
res = c.sum().compute()
46+
self.assertTrue(res == 10*10*4)
47+
48+
def test_input_multireturn(self):
49+
m = self.sds.full((10, 10), 2)
50+
[a, b, c] = scale(m, True, True)
51+
arr = self.sds.array(a, b, c)
52+
c = self.sds.source(self.source_path, "test").func(arr)
53+
res = c.sum().compute()
54+
self.assertTrue(res == 0)
55+
56+
# [SYSTEMDS-3224] https://issues.apache.org/jira/browse/SYSTEMDS-3224
57+
# def test_multi_return(self):
58+
# arr = self.sds.array(
59+
# self.sds.full((10, 10), 4),
60+
# self.sds.full((3, 3), 5))
61+
# [b, c] = self.sds.source(self.source_path, "test", True).func2(arr)
62+
# res = c.sum().compute()
63+
# self.assertTrue(res == 10*10*4)

0 commit comments

Comments
 (0)