Skip to content

Commit b854305

Browse files
Merge branch 'main' into einsum
2 parents ae4bb82 + c72461f commit b854305

File tree

85 files changed

+3292
-359
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+3292
-359
lines changed

.github/workflows/javaTests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ on:
2929
- '*.html'
3030
- 'src/main/python/**'
3131
- 'dev/**'
32+
- '.github/workflows/python.yml'
3233
branches:
3334
- main
3435
pull_request:
@@ -38,6 +39,7 @@ on:
3839
- '*.html'
3940
- 'src/main/python/**'
4041
- 'dev/**'
42+
- '.github/workflows/python.yml'
4143
branches:
4244
- main
4345

.github/workflows/python.yml

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,17 @@ on:
4545
jobs:
4646
test:
4747
runs-on: ${{ matrix.os }}
48+
timeout-minutes: 60
4849
strategy:
4950
fail-fast: false
5051
matrix:
5152
python-version: [3.8]
5253
os: [ubuntu-24.04]
5354
java: ['17']
5455
javadist: ['adopt']
56+
test_mode: [env, noenv, federated]
5557

56-
name: ${{ matrix.os }} Java ${{ matrix.java }} ${{ matrix.javadist }} Python ${{ matrix.python-version }}
58+
name: ${{ matrix.os }} Java ${{ matrix.java }} ${{ matrix.javadist }} Python ${{ matrix.python-version }}/ ${{ matrix.test_mode}}
5759
steps:
5860
- name: Checkout Repository
5961
uses: actions/checkout@v4
@@ -115,7 +117,6 @@ jobs:
115117
librosa \
116118
h5py \
117119
gensim \
118-
black \
119120
opt-einsum \
120121
nltk
121122
@@ -124,38 +125,29 @@ jobs:
124125
cd src/main/python
125126
python create_python_dist.py
126127
127-
- name: Run all python tests
128+
- name: Run tests with env
129+
if: ${{ matrix.test_mode == 'env' }}
128130
run: |
129131
export SYSTEMDS_ROOT=$(pwd)
130132
export PATH=$SYSTEMDS_ROOT/bin:$PATH
131133
export SYSDS_QUIET=1
132134
export LOG4JPROP=$SYSTEMDS_ROOT/src/test/resources/log4j.properties
133135
cd src/main/python
134136
unittest-parallel -t . -s tests -v
135-
# python -m unittest discover -s tests -p 'test_*.py'
136-
echo "Exit Status: " $?
137137
138-
- name: Run all python tests no environment
138+
- name: Run tests no env
139+
if: ${{ matrix.test_mode == 'noenv' }}
139140
run: |
140141
export LOG4JPROP=$(pwd)/src/test/resources/log4j.properties
141142
cd src/main/python
142143
unittest-parallel -t . -s tests -v
143-
# python -m unittest discover -s tests -p 'test_*.py'
144-
echo "Exit Status: " $?
145144
146145
- name: Run Federated Python Tests
146+
if: ${{ matrix.test_mode == 'federated' }}
147147
run: |
148148
export SYSTEMDS_ROOT=$(pwd)
149149
export PATH=$SYSTEMDS_ROOT/bin:$PATH
150150
cd src/main/python
151151
./tests/federated/runFedTest.sh
152152
153-
- name: Check formatting according to Black (src/main/python/systemds)
154-
run: |
155-
black --check --exclude operator/algorithm src/main/python/systemds
156-
157-
- name: Check formatting according to Black (src/main/python/tests)
158-
run: |
159-
black --check src/main/python/tests
160-
161153
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
name: Python Black Format Check
23+
24+
on:
25+
push:
26+
paths:
27+
- 'src/main/python/**'
28+
branches:
29+
- main
30+
pull_request:
31+
paths:
32+
- 'src/main/python/**'
33+
branches:
34+
- main
35+
36+
jobs:
37+
black:
38+
runs-on: ubuntu-latest
39+
steps:
40+
- name: Checkout Repository
41+
uses: actions/checkout@v4
42+
43+
- name: Setup Python
44+
uses: actions/setup-python@v5
45+
with:
46+
python-version: '3.x'
47+
48+
- name: Install Black
49+
run: |
50+
pip install --upgrade pip
51+
pip install black
52+
53+
- name: Run Black Check
54+
run: |
55+
black --check --exclude operator/algorithm \
56+
src/main/python/systemds src/main/python/tests

docs/site/dml-language-reference.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ quantile () | The p-quantile for a random variable X is the value x such that Pr
702702
quantile () | Returns a column matrix with list of all quantiles requested in P. | Input: (X &lt;(n x 1) matrix&gt;, [W &lt;(n x 1) matrix&gt;),] P &lt;(q x 1) matrix&gt;) <br/> Output: matrix | quantile(X, P) <br/> quantile(X, W, P)
703703
median() | Computes the median in a given column matrix of values | Input: (X &lt;(n x 1) matrix&gt;, [W &lt;(n x 1) matrix&gt;),]) <br/> Output: &lt;scalar&gt; | median(X) <br/> median(X,W)
704704
rowSums() <br/> rowMeans() <br/> rowVars() <br/> rowSds() <br/> rowMaxs() <br/> rowMins() | Row-wise computations -- for each row, compute the sum/mean/variance/stdDev/max/min of cell value | Input: matrix <br/> Output: (n x 1) matrix | rowSums(X) <br/> rowMeans(X) <br/> rowVars(X) <br/> rowSds(X) <br/> rowMaxs(X) <br/> rowMins(X)
705-
cumsum() | Column prefix-sum (For row-prefix sum, use cumsum(t(X)) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumsum(A) <br/> The output matrix B = [[1, 2], [4, 6], [9, 12]]
705+
cumsum() | Column prefix-sum | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumsum(A) <br/> The output matrix B = [[1, 2], [4, 6], [9, 12]]
706+
rowcumsum() | Row prefix-sum | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=2, cols=3) <br/> B = rowcumsum(A) <br/> The output matrix B = [[1, 3, 6], [4, 9, 15]]
706707
cumprod() | Column prefix-prod (For row-prefix prod, use cumprod(t(X)) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumprod(A) <br/> The output matrix B = [[1, 2], [3, 8], [15, 48]]
707708
cummin() | Column prefix-min (For row-prefix min, use cummin(t(X)) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2", rows=3, cols=2) <br/> B = cummin(A) <br/> The output matrix B = [[3, 4], [1, 4], [1, 2]]
708709
cummax() | Column prefix-max (For row-prefix min, use cummax(t(X)) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2", rows=3, cols=2) <br/> B = cummax(A) <br/> The output matrix B = [[3, 4], [3, 6], [5, 6]]

src/main/java/org/apache/sysds/api/PythonDMLScript.java

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,40 @@
2525
import org.apache.log4j.Logger;
2626
import org.apache.sysds.api.jmlc.Connection;
2727

28+
import org.apache.sysds.common.Types;
29+
import org.apache.sysds.runtime.DMLRuntimeException;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.util.CommonThreadPool;
32+
import org.apache.sysds.runtime.util.UnixPipeUtils;
2833
import py4j.DefaultGatewayServerListener;
2934
import py4j.GatewayServer;
3035
import py4j.Py4JNetworkException;
3136

37+
import java.io.BufferedInputStream;
38+
import java.io.BufferedOutputStream;
39+
import java.io.IOException;
40+
import java.util.ArrayList;
41+
import java.util.HashMap;
42+
import java.util.List;
43+
import java.util.concurrent.Callable;
44+
import java.util.concurrent.ExecutionException;
45+
import java.util.concurrent.ExecutorService;
46+
import java.util.concurrent.Future;
47+
3248

3349
public class PythonDMLScript {
3450

3551
private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
3652
final private Connection _connection;
3753
public static GatewayServer GwS;
3854

55+
private static String fromPythonBase = "py2java";
56+
private static String toPythonBase = "java2py";
57+
public HashMap<Integer, BufferedInputStream> fromPython = null;
58+
public HashMap<Integer, BufferedOutputStream> toPython = null;
59+
public String baseDir;
60+
private static int BATCH_SIZE = 32*1024;
61+
3962
/**
4063
* Entry point for Python API.
4164
*
@@ -78,6 +101,103 @@ public Connection getConnection() {
78101
return _connection;
79102
}
80103

104+
105+
public void openPipes(String path, int num) throws IOException {
106+
fromPython = new HashMap<>(num * 2);
107+
toPython = new HashMap<>(num * 2);
108+
baseDir = path;
109+
for (int i = 0; i < num; i++) {
110+
BufferedInputStream pipe_in = UnixPipeUtils.openInput(path + "/" + fromPythonBase + "-" + i, i);
111+
LOG.debug("PY2JAVA pipe "+i+" is ready!");
112+
fromPython.put(i, pipe_in);
113+
114+
BufferedOutputStream pipe_out = UnixPipeUtils.openOutput(path + "/" + toPythonBase + "-" + i, i);
115+
toPython.put(i, pipe_out);
116+
}
117+
}
118+
119+
public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen, Types.ValueType type) throws IOException {
120+
long limit = (long) rlen * clen;
121+
LOG.debug("trying to read matrix from "+id+" with "+rlen+" rows and "+clen+" columns. Total size: "+limit);
122+
if(limit > Integer.MAX_VALUE)
123+
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
124+
" cannot be converted to MatrixBlock");
125+
MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
126+
if(fromPython != null){
127+
BufferedInputStream pipe = fromPython.get(id);
128+
double[] denseBlock = new double[(int) limit];
129+
UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, (int) limit, type, denseBlock, 0);
130+
mb.init(denseBlock, rlen, clen);
131+
} else {
132+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
133+
}
134+
mb.recomputeNonZeros();
135+
mb.examSparsity();
136+
LOG.debug("Reading from Python finished");
137+
return mb;
138+
}
139+
140+
public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen, int clen, Types.ValueType type) throws ExecutionException, InterruptedException {
141+
long limit = (long) rlen * clen;
142+
if(limit > Integer.MAX_VALUE)
143+
throw new DMLRuntimeException("Dense NumPy array of size " + limit +
144+
" cannot be converted to MatrixBlock");
145+
MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
146+
if(fromPython != null){
147+
ExecutorService pool = CommonThreadPool.get();
148+
double[] denseBlock = new double[(int) limit];
149+
int offsetOut = 0;
150+
List<Future<Void>> futures = new ArrayList<>();
151+
for (int i = 0; i < blockSizes.length; i++) {
152+
BufferedInputStream pipe = fromPython.get(i);
153+
int id = i, blockSize = blockSizes[i], _offsetOut = offsetOut;
154+
Callable<Void> task = () -> {
155+
UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type, denseBlock, _offsetOut);
156+
return null;
157+
};
158+
159+
futures.add(pool.submit(task));
160+
offsetOut += blockSize;
161+
}
162+
// Wait for all tasks and propagate exceptions
163+
for (Future<Void> f : futures) {
164+
f.get();
165+
}
166+
167+
mb.init(denseBlock, rlen, clen);
168+
} else {
169+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
170+
}
171+
mb.recomputeNonZeros();
172+
mb.examSparsity();
173+
return mb;
174+
}
175+
176+
public void startWritingMbToPipe(int id, MatrixBlock mb) throws IOException {
177+
if (toPython != null) {
178+
int rlen = mb.getNumRows();
179+
int clen = mb.getNumColumns();
180+
int numElem = rlen * clen;
181+
LOG.debug("Trying to write matrix ["+baseDir + "-"+ id+"] with "+rlen+" rows and "+clen+" columns. Total size: "+numElem*8);
182+
183+
BufferedOutputStream out = toPython.get(id);
184+
long bytes = UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, numElem, Types.ValueType.FP64, mb);
185+
186+
LOG.debug("Writing of " + bytes +" Bytes to Python ["+baseDir + "-"+ id+"] finished");
187+
} else {
188+
throw new DMLRuntimeException("FIFO Pipes are not initialized.");
189+
}
190+
}
191+
192+
public void closePipes() throws IOException {
193+
LOG.debug("Closing all pipes in Java");
194+
for (BufferedInputStream pipe : fromPython.values())
195+
pipe.close();
196+
for (BufferedOutputStream pipe : toPython.values())
197+
pipe.close();
198+
LOG.debug("Closed all pipes in Java");
199+
}
200+
81201
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
82202
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());
83203

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ public enum Builtins {
291291
ROLL("roll", false),
292292
ROUND("round", false),
293293
ROW_COUNT_DISTINCT("rowCountDistinct",false),
294+
ROWCUMSUM("rowcumsum", false),
294295
ROWINDEXMAX("rowIndexMax", false),
295296
ROWINDEXMIN("rowIndexMin", false),
296297
ROWMAX("rowMaxs", false),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public enum Opcodes {
3434
UAKP("uak+", InstructionType.AggregateUnary),
3535
UARKP("uark+", InstructionType.AggregateUnary),
3636
UACKP("uack+", InstructionType.AggregateUnary),
37+
UARCKP("uarck+", InstructionType.AggregateUnary),
3738
UASQKP("uasqk+", InstructionType.AggregateUnary),
3839
UARSQKP("uarsqk+", InstructionType.AggregateUnary),
3940
UACSQKP("uacsqk+", InstructionType.AggregateUnary),
@@ -151,6 +152,7 @@ public enum Opcodes {
151152
CEIL("ceil", InstructionType.Unary),
152153
FLOOR("floor", InstructionType.Unary),
153154
UCUMKP("ucumk+", InstructionType.Unary),
155+
UROWCUMKP("urowcumk+", InstructionType.Unary),
154156
UCUMM("ucum*", InstructionType.Unary),
155157
UCUMKPM("ucumk+*", InstructionType.Unary),
156158
UCUMMIN("ucummin", InstructionType.Unary),
@@ -383,6 +385,7 @@ public enum Opcodes {
383385
UCUMACMIN("ucumacmin", InstructionType.CumsumAggregate),
384386
UCUMACMAX("ucumacmax", InstructionType.CumsumAggregate),
385387
BCUMOFFKP("bcumoffk+", InstructionType.CumsumOffset),
388+
BROWCUMOFFKP("browcumoffk+", InstructionType.CumsumOffset),
386389
BCUMOFFM("bcumoff*", InstructionType.CumsumOffset),
387390
BCUMOFFPM("bcumoff+*", InstructionType.CumsumOffset),
388391
BCUMOFFMIN("bcumoffmin", InstructionType.CumsumOffset),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ public enum OpOp1 {
547547
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
548548
CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
549549
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
550-
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
550+
MEDIAN, PREFETCH, PRINT, ROUND, ROWCUMSUM, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
551551
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
552552
//fused ML-specific operators for performance
553553
SPROP, //sample proportion: P * (1 - P)
@@ -591,6 +591,7 @@ public String toString() {
591591
case MULT2: return Opcodes.MULT2.toString();
592592
case NOT: return Opcodes.NOT.toString();
593593
case POW2: return Opcodes.POW2.toString();
594+
case ROWCUMSUM: return Opcodes.UROWCUMKP.toString();
594595
case TYPEOF: return Opcodes.TYPEOF.toString();
595596
default: return name().toLowerCase();
596597
}
@@ -610,6 +611,7 @@ public static OpOp1 valueOfByOpcode(String opcode) {
610611
case "ucummin": return CUMMIN;
611612
case "ucum*": return CUMPROD;
612613
case "ucumk+": return CUMSUM;
614+
case "urowcumk+": return ROWCUMSUM;
613615
case "ucumk+*": return CUMSUMPROD;
614616
case "detectSchema": return DETECTSCHEMA;
615617
case "*2": return MULT2;

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
205205
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
206206
hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
207207
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
208+
hi = simplifyTransposedCumsum(hop, hi, i); //e.g., t(cumsum(t(X))) -> rowcumsum(X)
208209

209210
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
210211
if( !descendFirst )
@@ -214,6 +215,28 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
214215
hop.setVisited();
215216
}
216217

218+
private static Hop simplifyTransposedCumsum( Hop parent, Hop hi, int pos )
219+
{
220+
//e.g., t(cumsum(t(X))) -> rowcumsum(X)
221+
if( HopRewriteUtils.isTransposeOperation(hi)
222+
&& hi.getInput(0) instanceof UnaryOp
223+
&& ((UnaryOp)hi.getInput(0)).getOp() == OpOp1.CUMSUM
224+
&& hi.getInput(0).getParent().size() == 1
225+
&& HopRewriteUtils.isTransposeOperation(hi.getInput(0).getInput(0), 1)) //inner transpose with single consumer
226+
{
227+
UnaryOp cumsum=(UnaryOp)hi.getInput(0);
228+
Hop innerMatrix = cumsum.getInput(0).getInput(0);
229+
230+
UnaryOp rowcumsumOp = HopRewriteUtils.createUnary(innerMatrix, OpOp1.ROWCUMSUM);
231+
HopRewriteUtils.replaceChildReference(parent,hi, rowcumsumOp, pos);
232+
233+
hi = rowcumsumOp;
234+
LOG.debug("Applied simplifyTransposedCumsum (line "+hi.getBeginLine()+").");
235+
}
236+
237+
return hi;
238+
}
239+
217240
private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) {
218241
if (!(hi instanceof BinaryOp))
219242
return hi;

0 commit comments

Comments
 (0)