Skip to content

Commit 0386fe1

Browse files
e-straussBaunsgaard
authored andcommitted
[MINOR] Fix in Python API GatewayServerListener
The current inner class DMLGateWayListener implements function from the GatewayServerListener interface, which are never invoked by the GatewayServer (since the GatewayServer, which also implements GatewayServerListener, does not implement these methods. Furthermore, DMLGateWayListener previously called, Sys.exit(), which I think is not correct, since it breaks the proper shutdown of the GatewayServer. Finally, this commit added a new unit case, which checks the functionality of the DMLGateWayListener. While merging, we verified that the additions did not contain any regressions in startup and shutdown of the Python API. Closes #2243
1 parent 9a09c45 commit 0386fe1

File tree

2 files changed

+101
-29
lines changed

2 files changed

+101
-29
lines changed

src/main/java/org/apache/sysds/api/PythonDMLScript.java

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,20 @@
2121

2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
24+
import org.apache.log4j.Level;
25+
import org.apache.log4j.Logger;
2426
import org.apache.sysds.api.jmlc.Connection;
2527

28+
import py4j.DefaultGatewayServerListener;
2629
import py4j.GatewayServer;
27-
import py4j.GatewayServerListener;
2830
import py4j.Py4JNetworkException;
29-
import py4j.Py4JServerConnection;
31+
3032

3133
public class PythonDMLScript {
3234

3335
private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
3436
final private Connection _connection;
37+
public static GatewayServer GwS;
3538

3639
/**
3740
* Entry point for Python API.
@@ -42,7 +45,7 @@ public class PythonDMLScript {
4245
public static void main(String[] args) throws Exception {
4346
final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
4447
DMLScript.loadConfiguration(dmlOptions.configFile);
45-
final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
48+
GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
4649
GwS.addListener(new DMLGateWayListener());
4750
try {
4851
GwS.start();
@@ -67,38 +70,20 @@ private PythonDMLScript() {
6770
_connection = new Connection();
6871
}
6972

73+
public static void setDMLGateWayListenerLoggerLevel(Level l){
74+
Logger.getLogger(DMLGateWayListener.class).setLevel(l);
75+
}
76+
7077
public Connection getConnection() {
7178
return _connection;
7279
}
7380

74-
protected static class DMLGateWayListener implements GatewayServerListener {
81+
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
7582
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());
7683

77-
@Override
78-
public void connectionError(Exception e) {
79-
LOG.warn("Connection error: " + e.getMessage());
80-
System.exit(1);
81-
}
82-
83-
@Override
84-
public void connectionStarted(Py4JServerConnection gatewayConnection) {
85-
LOG.debug("Connection Started: " + gatewayConnection.toString());
86-
}
87-
88-
@Override
89-
public void connectionStopped(Py4JServerConnection gatewayConnection) {
90-
LOG.debug("Connection stopped: " + gatewayConnection.toString());
91-
}
92-
93-
@Override
94-
public void serverError(Exception e) {
95-
LOG.error("Server Error " + e.getMessage());
96-
}
97-
9884
@Override
9985
public void serverPostShutdown() {
10086
LOG.info("Shutdown done");
101-
System.exit(0);
10287
}
10388

10489
@Override
@@ -108,13 +93,12 @@ public void serverPreShutdown() {
10893

10994
@Override
11095
public void serverStarted() {
111-
LOG.info("GatewayServer Started");
96+
LOG.info("GatewayServer started");
11297
}
11398

11499
@Override
115100
public void serverStopped() {
116-
LOG.info("GatewayServer Stopped");
117-
System.exit(0);
101+
LOG.info("GatewayServer stopped");
118102
}
119103
}
120104

src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,55 @@
1919

2020
package org.apache.sysds.test.usertest.pythonapi;
2121

22+
import org.apache.log4j.Level;
23+
import org.apache.log4j.Logger;
24+
import org.apache.log4j.spi.LoggingEvent;
2225
import org.apache.sysds.api.PythonDMLScript;
26+
import org.apache.sysds.test.LoggingUtils;
27+
import org.junit.After;
28+
import org.junit.Assert;
29+
import org.junit.Before;
2330
import org.junit.Test;
31+
import py4j.GatewayServer;
32+
33+
import java.security.Permission;
34+
import java.util.List;
35+
2436

2537
/** Simple tests to verify startup of Python Gateway server happens without crashes */
2638
public class StartupTest {
39+
private LoggingUtils.TestAppender appender;
40+
private SecurityManager sm;
41+
42+
@Before
43+
public void setUp() {
44+
appender = LoggingUtils.overwrite();
45+
sm = System.getSecurityManager();
46+
System.setSecurityManager(new NoExitSecurityManager());
47+
PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL);
48+
Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL);
49+
}
50+
51+
@After
52+
public void tearDown() {
53+
LoggingUtils.reinsert(appender);
54+
System.setSecurityManager(sm);
55+
}
56+
57+
private void assertLogMessages(String... expectedMessages) {
58+
List<LoggingEvent> log = LoggingUtils.reinsert(appender);
59+
log.stream().forEach(l -> System.out.println(l.getMessage()));
60+
Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size());
61+
62+
for (int i = 0; i < expectedMessages.length; i++) {
63+
// order does not matter
64+
boolean found = false;
65+
for (String message : expectedMessages) {
66+
found |= log.get(i).getMessage().toString().startsWith(message);
67+
}
68+
Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found);
69+
}
70+
}
2771

2872
@Test(expected = Exception.class)
2973
public void testStartupIncorrect_1() throws Exception {
@@ -50,4 +94,48 @@ public void testStartupIncorrect_5() throws Exception {
5094
// Number out of range
5195
PythonDMLScript.main(new String[] {"-python", "918757"});
5296
}
97+
98+
@Test
99+
public void testStartupIncorrect_6() throws Exception {
100+
GatewayServer gws1 = null;
101+
try {
102+
PythonDMLScript.main(new String[]{"-python", "4001"});
103+
gws1 = PythonDMLScript.GwS;
104+
Thread.sleep(200);
105+
PythonDMLScript.main(new String[]{"-python", "4001"});
106+
Thread.sleep(200);
107+
} catch (SecurityException e) {
108+
assertLogMessages(
109+
"GatewayServer started",
110+
"failed startup"
111+
);
112+
gws1.shutdown();
113+
}
114+
}
115+
116+
@Test
117+
public void testStartupCorrect() throws Exception {
118+
PythonDMLScript.main(new String[]{"-python", "4002"});
119+
Thread.sleep(200);
120+
PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
121+
script.getConnection();
122+
PythonDMLScript.GwS.shutdown();
123+
Thread.sleep(200);
124+
assertLogMessages(
125+
"GatewayServer started",
126+
"Starting JVM shutdown",
127+
"Shutdown done",
128+
"GatewayServer stopped"
129+
);
130+
}
131+
132+
class NoExitSecurityManager extends SecurityManager {
133+
@Override
134+
public void checkPermission(Permission perm) { }
135+
136+
@Override
137+
public void checkExit(int status) {
138+
throw new SecurityException("Intercepted exit()");
139+
}
140+
}
53141
}

0 commit comments

Comments
 (0)