Skip to content

Commit 264bdcd

Browse files
committed
[SYSTEMDS-3902] Accelerated data transfer Python <--> JVM
Introduced a new data transfer mechanism on Unix systems that uses FIFO (named) pipes as a faster alternative to py4j-based communication.
- Supports multiple value types (uint8, int32, fp32, fp64) for dense matrix exchange.
- Adds experimental support for partitioned matrix transfer from Python to Java via multiple concurrent pipes (disabled by default due to limited performance improvement).
- Significantly reduces overhead compared to py4j for large matrix transfers in supported scenarios.

Closes #2296.
1 parent 9984092 commit 264bdcd

File tree

12 files changed

+1161
-30
lines changed

12 files changed

+1161
-30
lines changed

.github/workflows/javaTests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ on:
2929
- '*.html'
3030
- 'src/main/python/**'
3131
- 'dev/**'
32+
- '.github/workflows/python.yml'
3233
branches:
3334
- main
3435
pull_request:
@@ -38,6 +39,7 @@ on:
3839
- '*.html'
3940
- 'src/main/python/**'
4041
- 'dev/**'
42+
- '.github/workflows/python.yml'
4143
branches:
4244
- main
4345

src/main/java/org/apache/sysds/api/PythonDMLScript.java

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,40 @@
2525
import org.apache.log4j.Logger;
2626
import org.apache.sysds.api.jmlc.Connection;
2727

28+
import org.apache.sysds.common.Types;
29+
import org.apache.sysds.runtime.DMLRuntimeException;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.util.CommonThreadPool;
32+
import org.apache.sysds.runtime.util.UnixPipeUtils;
2833
import py4j.DefaultGatewayServerListener;
2934
import py4j.GatewayServer;
3035
import py4j.Py4JNetworkException;
3136

37+
import java.io.BufferedInputStream;
38+
import java.io.BufferedOutputStream;
39+
import java.io.IOException;
40+
import java.util.ArrayList;
41+
import java.util.HashMap;
42+
import java.util.List;
43+
import java.util.concurrent.Callable;
44+
import java.util.concurrent.ExecutionException;
45+
import java.util.concurrent.ExecutorService;
46+
import java.util.concurrent.Future;
47+
3248

3349
public class PythonDMLScript {
3450

3551
private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
3652
final private Connection _connection;
3753
public static GatewayServer GwS;
3854

55+
private static String fromPythonBase = "py2java";
56+
private static String toPythonBase = "java2py";
57+
public HashMap<Integer, BufferedInputStream> fromPython = null;
58+
public HashMap<Integer, BufferedOutputStream> toPython = null;
59+
public String baseDir;
60+
private static int BATCH_SIZE = 32*1024;
61+
3962
/**
4063
* Entry point for Python API.
4164
*
@@ -78,6 +101,103 @@ public Connection getConnection() {
78101
return _connection;
79102
}
80103

104+
105+
public void openPipes(String path, int num) throws IOException {
106+
fromPython = new HashMap<>(num * 2);
107+
toPython = new HashMap<>(num * 2);
108+
baseDir = path;
109+
for (int i = 0; i < num; i++) {
110+
BufferedInputStream pipe_in = UnixPipeUtils.openInput(path + "/" + fromPythonBase + "-" + i, i);
111+
LOG.debug("PY2JAVA pipe "+i+" is ready!");
112+
fromPython.put(i, pipe_in);
113+
114+
BufferedOutputStream pipe_out = UnixPipeUtils.openOutput(path + "/" + toPythonBase + "-" + i, i);
115+
toPython.put(i, pipe_out);
116+
}
117+
}
118+
119+
public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen, Types.ValueType type) throws IOException {
120+
long limit = (long) rlen * clen;
121+
LOG.debug("trying to read matrix from "+id+" with "+rlen+" rows and "+clen+" columns. Total size: "+limit);
122+
if(limit > Integer.MAX_VALUE)
123+
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
124+
" cannot be converted to MatrixBlock");
125+
MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
126+
if(fromPython != null){
127+
BufferedInputStream pipe = fromPython.get(id);
128+
double[] denseBlock = new double[(int) limit];
129+
UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, (int) limit, type, denseBlock, 0);
130+
mb.init(denseBlock, rlen, clen);
131+
} else {
132+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
133+
}
134+
mb.recomputeNonZeros();
135+
mb.examSparsity();
136+
LOG.debug("Reading from Python finished");
137+
return mb;
138+
}
139+
140+
public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen, int clen, Types.ValueType type) throws ExecutionException, InterruptedException {
141+
long limit = (long) rlen * clen;
142+
if(limit > Integer.MAX_VALUE)
143+
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
144+
" cannot be converted to MatrixBlock");
145+
MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
146+
if(fromPython != null){
147+
ExecutorService pool = CommonThreadPool.get();
148+
double[] denseBlock = new double[(int) limit];
149+
int offsetOut = 0;
150+
List<Future<Void>> futures = new ArrayList<>();
151+
for (int i = 0; i < blockSizes.length; i++) {
152+
BufferedInputStream pipe = fromPython.get(i);
153+
int id = i, blockSize = blockSizes[i], _offsetOut = offsetOut;
154+
Callable<Void> task = () -> {
155+
UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type, denseBlock, _offsetOut);
156+
return null;
157+
};
158+
159+
futures.add(pool.submit(task));
160+
offsetOut += blockSize;
161+
}
162+
// Wait for all tasks and propagate exceptions
163+
for (Future<Void> f : futures) {
164+
f.get();
165+
}
166+
167+
mb.init(denseBlock, rlen, clen);
168+
} else {
169+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
170+
}
171+
mb.recomputeNonZeros();
172+
mb.examSparsity();
173+
return mb;
174+
}
175+
176+
public void startWritingMbToPipe(int id, MatrixBlock mb) throws IOException {
177+
if (toPython != null) {
178+
int rlen = mb.getNumRows();
179+
int clen = mb.getNumColumns();
180+
int numElem = rlen * clen;
181+
LOG.debug("Trying to write matrix ["+baseDir + "-"+ id+"] with "+rlen+" rows and "+clen+" columns. Total size: "+numElem*8);
182+
183+
BufferedOutputStream out = toPython.get(id);
184+
long bytes = UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, numElem, Types.ValueType.FP64, mb);
185+
186+
LOG.debug("Writing of " + bytes +" Bytes to Python ["+baseDir + "-"+ id+"] finished");
187+
} else {
188+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
189+
}
190+
}
191+
192+
public void closePipes() throws IOException {
193+
LOG.debug("Closing all pipes in Java");
194+
for (BufferedInputStream pipe : fromPython.values())
195+
pipe.close();
196+
for (BufferedOutputStream pipe : toPython.values())
197+
pipe.close();
198+
LOG.debug("Closed all pipes in Java");
199+
}
200+
81201
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
82202
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());
83203

0 commit comments

Comments
 (0)