Skip to content

Commit 9b0e68e

Browse files
committed
[MINOR] Fix in Python API GatewayServerListener
1 parent cfbe190 commit 9b0e68e

File tree

2 files changed

+63
-29
lines changed

2 files changed

+63
-29
lines changed

src/main/java/org/apache/sysds/api/PythonDMLScript.java

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,20 @@
2121

2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
24+
import org.apache.log4j.Level;
25+
import org.apache.log4j.Logger;
2426
import org.apache.sysds.api.jmlc.Connection;
2527

28+
import py4j.DefaultGatewayServerListener;
2629
import py4j.GatewayServer;
27-
import py4j.GatewayServerListener;
2830
import py4j.Py4JNetworkException;
29-
import py4j.Py4JServerConnection;
31+
3032

3133
public class PythonDMLScript {
3234

3335
private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
3436
final private Connection _connection;
37+
public static GatewayServer GwS;
3538

3639
/**
3740
* Entry point for Python API.
@@ -42,7 +45,7 @@ public class PythonDMLScript {
4245
public static void main(String[] args) throws Exception {
4346
final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
4447
DMLScript.loadConfiguration(dmlOptions.configFile);
45-
final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
48+
GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
4649
GwS.addListener(new DMLGateWayListener());
4750
try {
4851
GwS.start();
@@ -67,38 +70,20 @@ private PythonDMLScript() {
6770
_connection = new Connection();
6871
}
6972

73+
public static void setDMLGateWayListenerLoggerLevel(Level l){
74+
Logger.getLogger(DMLGateWayListener.class).setLevel(l);
75+
}
76+
7077
public Connection getConnection() {
7178
return _connection;
7279
}
7380

74-
protected static class DMLGateWayListener implements GatewayServerListener {
81+
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
7582
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());
7683

77-
@Override
78-
public void connectionError(Exception e) {
79-
LOG.warn("Connection error: " + e.getMessage());
80-
System.exit(1);
81-
}
82-
83-
@Override
84-
public void connectionStarted(Py4JServerConnection gatewayConnection) {
85-
LOG.debug("Connection Started: " + gatewayConnection.toString());
86-
}
87-
88-
@Override
89-
public void connectionStopped(Py4JServerConnection gatewayConnection) {
90-
LOG.debug("Connection stopped: " + gatewayConnection.toString());
91-
}
92-
93-
@Override
94-
public void serverError(Exception e) {
95-
LOG.error("Server Error " + e.getMessage());
96-
}
97-
9884
@Override
9985
public void serverPostShutdown() {
10086
LOG.info("Shutdown done");
101-
System.exit(0);
10287
}
10388

10489
@Override
@@ -108,13 +93,12 @@ public void serverPreShutdown() {
10893

10994
@Override
11095
public void serverStarted() {
111-
LOG.info("GatewayServer Started");
96+
LOG.info("GatewayServer started");
11297
}
11398

11499
@Override
115100
public void serverStopped() {
116-
LOG.info("GatewayServer Stopped");
117-
System.exit(0);
101+
LOG.info("GatewayServer stopped");
118102
}
119103
}
120104

src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,47 @@
1919

2020
package org.apache.sysds.test.usertest.pythonapi;
2121

22+
import org.apache.log4j.Level;
23+
import org.apache.log4j.spi.LoggingEvent;
2224
import org.apache.sysds.api.PythonDMLScript;
25+
import org.apache.sysds.test.LoggingUtils;
26+
import org.junit.After;
27+
import org.junit.Assert;
28+
import org.junit.Before;
2329
import org.junit.Test;
2430

31+
import java.util.List;
32+
33+
2534
/** Simple tests to verify startup of Python Gateway server happens without crashes */
2635
public class StartupTest {
36+
private LoggingUtils.TestAppender appender;
37+
38+
@Before
39+
public void setUp() {
40+
appender = LoggingUtils.overwrite();
41+
PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL);
42+
}
43+
44+
@After
45+
public void tearDown() {
46+
LoggingUtils.reinsert(appender);
47+
}
48+
49+
private void assertLogMessages(String... expectedMessages) {
50+
List<LoggingEvent> log = LoggingUtils.reinsert(appender);
51+
log.stream().forEach(l -> System.out.println(l.getMessage()));
52+
Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size());
53+
54+
for (int i = 0; i < expectedMessages.length; i++) {
55+
// order does not matter
56+
boolean found = false;
57+
for (String message : expectedMessages) {
58+
found |= log.get(i).getMessage().toString().startsWith(message);
59+
}
60+
Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found);
61+
}
62+
}
2763

2864
@Test(expected = Exception.class)
2965
public void testStartupIncorrect_1() throws Exception {
@@ -50,4 +86,18 @@ public void testStartupIncorrect_5() throws Exception {
5086
// Number out of range
5187
PythonDMLScript.main(new String[] {"-python", "918757"});
5288
}
89+
90+
@Test
91+
public void testStartupCorrect() throws Exception {
92+
PythonDMLScript.main(new String[]{"-python", "4001"});
93+
Thread.sleep(200);
94+
PythonDMLScript.GwS.shutdown();
95+
Thread.sleep(200);
96+
assertLogMessages(
97+
"GatewayServer started",
98+
"Starting JVM shutdown",
99+
"Shutdown done",
100+
"GatewayServer stopped"
101+
);
102+
}
53103
}

0 commit comments

Comments
 (0)