1919#
2020# -------------------------------------------------------------
2121
22- from typing import Any , Collection , KeysView , Tuple , Union , Optional , Dict , TYPE_CHECKING , List
22+ from typing import (TYPE_CHECKING , Any , Collection , Dict , KeysView , List ,
23+ Optional , Tuple , Union )
2324
25+ from py4j .protocol import Py4JNetworkError
2426from py4j .java_collections import JavaArray
25- from py4j .java_gateway import JavaObject , JavaGateway
26-
27+ from py4j .java_gateway import JavaGateway , JavaObject
2728from systemds .script_building .dag import DAGNode , OutputType
2829from systemds .utils .consts import VALID_INPUT_TYPES
2930
@@ -79,9 +80,14 @@ def execute(self) -> JavaObject:
7980 self .__prepare_script ()
8081 ret = self .prepared_script .executeScript ()
8182 return ret
83+ except Py4JNetworkError :
84+ exception_str = "Py4JNetworkError: no connection to JVM, most likely due to previous crash or closed JVM from calls to close()"
85+ trace_back_limit = 0
8286 except Exception as e :
83- self .sds_context .exception_and_close (e )
84- return None
87+ exception_str = str (e )
88+ trace_back_limit = None
89+ self .sds_context .exception_and_close (exception_str , trace_back_limit )
90+
8591
8692 def execute_with_lineage (self ) -> Tuple [JavaObject , str ]:
8793 """If not already created, create a preparedScript from our DMLCode, pass python local data to our prepared
@@ -104,9 +110,13 @@ def execute_with_lineage(self) -> Tuple[JavaObject, str]:
104110 traces .append (self .prepared_script .getLineageTrace (output ))
105111 return ret , traces
106112
113+ except Py4JNetworkError :
114+ exception_str = "Py4JNetworkError: no connection to JVM, most likely due to previous crash or closed JVM from calls to close()"
115+ trace_back_limit = 0
107116 except Exception as e :
108- self .sds_context .exception_and_close (e )
109- return None , None
117+ exception_str = str (e )
118+ trace_back_limit = None
119+ self .sds_context .exception_and_close (exception_str , trace_back_limit )
110120
111121 def __prepare_script (self ):
112122 gateway = self .sds_context .java_gateway
@@ -190,15 +200,13 @@ def _dfs_dag_nodes(self, dag_node: VALID_INPUT_TYPES) -> str:
190200 # for each node do the dfs operation and save the variable names in `input_var_names`
191201 # get variable names of unnamed parameters
192202
193- unnamed_input_vars = [self ._dfs_dag_nodes (
194- input_node ) for input_node in dag_node .unnamed_input_nodes ]
203+ unnamed_input_vars = []
204+ for un_node in dag_node .unnamed_input_nodes :
205+ unnamed_input_vars .append (self ._dfs_dag_nodes (un_node ))
195206
196207 named_input_vars = {}
197208 for name , input_node in dag_node .named_input_nodes .items ():
198209 named_input_vars [name ] = self ._dfs_dag_nodes (input_node )
199- if isinstance (input_node , DAGNode ) and input_node ._output_type == OutputType .LIST :
200- dag_node .dml_name = named_input_vars [name ] + name
201- return dag_node .dml_name
202210
203211 # check if the node gets a name after multireturns
204212 # If it has, great, return that name
@@ -212,6 +220,7 @@ def _dfs_dag_nodes(self, dag_node: VALID_INPUT_TYPES) -> str:
212220
213221 code_line = dag_node .code_line (
214222 dag_node .dml_name , unnamed_input_vars , named_input_vars )
223+
215224 self .add_code (code_line )
216225 return dag_node .dml_name
217226
0 commit comments