Skip to content

Commit 470bb0b

Browse files
Biranavan Parameswaran authored and mboehm7 committed
[SYSTEMDS-3730] Multithreaded roll operation and improved tests
This patch introduces multi-threading support for the roll operation to improve performance. The RollTest.java has been updated to cover both single and multithreaded execution modes. Furthermore, this update adds comprehensive consistency checks to ensure mathematical correctness. New tests were created to validate both dense and sparse matrix inputs. Additionally, cross-verification tests were added to confirm that sparse and dense rolling for single and multithreaded executions produce identical results. Closes #2376.
1 parent eac7e0e commit 470bb0b

File tree

9 files changed

+774
-9
lines changed

9 files changed

+774
-9
lines changed

src/main/java/org/apache/sysds/hops/ReorgOp.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ _op, getDataType(), getValueType(), et,
173173
for (int i = 0; i < 2; i++)
174174
linputs[i] = getInput().get(i).constructLops();
175175

176-
Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1);
176+
Transform transform1 = new Transform(
177+
linputs, _op, getDataType(), getValueType(), et,
178+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
177179

178180
setOutputDimensions(transform1);
179181
setLineNumbers(transform1);

src/main/java/org/apache/sysds/lops/Transform.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ private String getInstructions(String input1, int numInputs, String output) {
180180
sb.append( this.prepOutputOperand(output));
181181

182182
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC)
183-
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
183+
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT || _operation == ReOrgOp.ROLL) ) {
184184
sb.append( OPERAND_DELIMITOR );
185185
sb.append( _numThreads );
186186
if ( getExecType()==ExecType.FED ) {

src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,13 @@ else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) {
118118
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str);
119119
}
120120
else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
121-
InstructionUtils.checkNumFields(str, 3);
121+
InstructionUtils.checkNumFields(str, 3, 4);
122122
in.split(parts[1]);
123123
out.split(parts[3]);
124124
CPOperand shift = new CPOperand(parts[2]);
125-
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str);
125+
int k = (parts.length > 4) ? Integer.parseInt(parts[4]) : 1;
126+
127+
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0), k), in, out, shift, opcode, str);
126128
}
127129
else if ( opcode.equalsIgnoreCase(Opcodes.DIAG.toString()) ) {
128130
parseUnaryInstruction(str, in, out); //max 2 operands

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ public static MatrixBlock reorg( MatrixBlock in, MatrixBlock out, ReorgOperator
134134
return rev(in, out);
135135
case ROLL:
136136
RollIndex rix = (RollIndex) op.fn;
137+
if(op.getNumThreads() > 1)
138+
return roll(in, out, rix.getShift(), op.getNumThreads());
137139
return roll(in, out, rix.getShift());
138140
case DIAG:
139141
return diag(in, out);
@@ -514,6 +516,124 @@ public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
514516
return out;
515517
}
516518

519+
public static MatrixBlock roll(MatrixBlock input, MatrixBlock output, int shift, int numThreads) {
520+
521+
final int numRows = input.rlen;
522+
final int numCols = input.clen;
523+
final boolean isSparse = input.sparse;
524+
525+
// sparse-safe operation
526+
if(input.isEmptyBlock(false))
527+
return output;
528+
529+
// special case: row vector
530+
if(numRows == 1) {
531+
output.copy(input);
532+
return output;
533+
}
534+
535+
if(numThreads <= 1 || input.getLength() < PAR_NUMCELL_THRESHOLD) {
536+
return roll(input, output, shift); // fallback to single-threaded
537+
}
538+
539+
final int normalizedShift = getNormalizedShiftForRoll(shift, numRows);
540+
541+
output.reset(numRows, numCols, isSparse);
542+
output.nonZeros = input.nonZeros;
543+
544+
if(isSparse) {
545+
output.allocateSparseRowsBlock(false);
546+
}
547+
else {
548+
output.allocateDenseBlock(false);
549+
}
550+
551+
//TODO experiment with more tasks per thread for better load balance
552+
//TODO call common kernel from both single- and multi-threaded execution
553+
554+
ExecutorService threadPool = CommonThreadPool.get(numThreads);
555+
try {
556+
final int rowsPerThread = (int) Math.ceil((double) numRows / numThreads);
557+
List<Future<?>> tasks = new ArrayList<>();
558+
559+
for(int threadIndex = 0; threadIndex < numThreads; threadIndex++) {
560+
561+
final int startRow = threadIndex * rowsPerThread;
562+
final int endRow = Math.min((threadIndex + 1) * rowsPerThread, numRows);
563+
564+
tasks.add(threadPool.submit(() -> {
565+
if(isSparse)
566+
rollSparseBlock(input, output, normalizedShift, startRow, endRow);
567+
else
568+
rollDenseBlock(input, output, normalizedShift, startRow, endRow);
569+
}));
570+
}
571+
572+
for(Future<?> task : tasks)
573+
task.get();
574+
575+
}
576+
catch(Exception ex) {
577+
throw new DMLRuntimeException(ex);
578+
}
579+
finally {
580+
threadPool.shutdown();
581+
}
582+
583+
return output;
584+
}
585+
586+
private static int getNormalizedShiftForRoll(int shift, int numRows) {
587+
shift = shift % numRows;
588+
if(shift < 0)
589+
shift += numRows;
590+
591+
return shift;
592+
}
593+
594+
private static void rollDenseBlock(MatrixBlock input, MatrixBlock output,
595+
int shift, int startRow, int endRow)
596+
{
597+
DenseBlock inputBlock = input.getDenseBlock();
598+
DenseBlock outputBlock = output.getDenseBlock();
599+
final int numRows = input.rlen;
600+
final int numCols = input.clen;
601+
602+
for(int targetRow = startRow; targetRow < endRow; targetRow++) {
603+
int sourceRow = targetRow - shift;
604+
if(sourceRow < 0)
605+
sourceRow += numRows;
606+
607+
System.arraycopy(inputBlock.values(sourceRow), inputBlock.pos(sourceRow), outputBlock.values(targetRow),
608+
outputBlock.pos(targetRow), numCols);
609+
}
610+
}
611+
612+
private static void rollSparseBlock(MatrixBlock input, MatrixBlock output,
613+
int shift, int startRow, int endRow)
614+
{
615+
SparseBlock inputBlock = input.getSparseBlock();
616+
SparseBlock outputBlock = output.getSparseBlock();
617+
final int numRows = input.rlen;
618+
619+
for(int targetRow = startRow; targetRow < endRow; targetRow++) {
620+
int sourceRow = targetRow - shift;
621+
if(sourceRow < 0)
622+
sourceRow += numRows;
623+
624+
if(!inputBlock.isEmpty(sourceRow)) {
625+
int rowStart = inputBlock.pos(sourceRow);
626+
int rowEnd = rowStart + inputBlock.size(sourceRow);
627+
int[] colIndexes = inputBlock.indexes(sourceRow);
628+
double[] values = inputBlock.values(sourceRow);
629+
630+
for(int k = rowStart; k < rowEnd; k++) {
631+
outputBlock.set(targetRow, colIndexes[k], values[k]);
632+
}
633+
}
634+
}
635+
}
636+
517637
public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
518638
MatrixIndexes inMtxIdx = in.getIndexes();
519639
MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
@@ -2554,15 +2674,15 @@ private static void reverseSparse(MatrixBlock in, MatrixBlock out, int rl, int r
25542674

25552675
private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
25562676
final int m = in.rlen;
2557-
shift %= (m != 0 ? m : 1); // roll matrix with axis=none
2677+
shift = getNormalizedShiftForRoll(shift, m); // roll matrix with axis=none
25582678

25592679
copyDenseMtx(in, out, 0, shift, m - shift, false, true);
25602680
copyDenseMtx(in, out, m - shift, 0, shift, true, true);
25612681
}
25622682

25632683
private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
25642684
final int m = in.rlen;
2565-
shift %= (m != 0 ? m : 1); // roll matrix with axis=0
2685+
shift = getNormalizedShiftForRoll(shift, m); // roll matrix with axis=0
25662686

25672687
copySparseMtx(in, out, 0, shift, m - shift, false, true);
25682688
copySparseMtx(in, out, m-shift, 0, shift, false, true);
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.performance.matrix;
21+
22+
import org.apache.sysds.performance.compression.APerfTest;
23+
import org.apache.sysds.performance.generators.ConstMatrix;
24+
import org.apache.sysds.performance.generators.IGenerate;
25+
import org.apache.sysds.runtime.functionobjects.IndexFunction;
26+
import org.apache.sysds.runtime.functionobjects.RollIndex;
27+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
28+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
29+
import org.apache.sysds.test.TestUtils;
30+
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
31+
32+
import java.util.Random;
33+
34+
public class MatrixRollPerf extends APerfTest<Object, MatrixBlock> {
35+
36+
private final int rows;
37+
private final int cols;
38+
private final int shift;
39+
private final int k;
40+
41+
private final ReorgOperator reorg;
42+
private MatrixBlock out;
43+
44+
public MatrixRollPerf(int N, int W, IGenerate<MatrixBlock> gen, int rows, int cols, int shift, int k) {
45+
super(N, W, gen);
46+
this.rows = rows;
47+
this.cols = cols;
48+
this.shift = shift;
49+
this.k = k;
50+
51+
IndexFunction op = new RollIndex(shift);
52+
this.reorg = new ReorgOperator(op, k);
53+
}
54+
55+
public void run() throws Exception {
56+
MatrixBlock mb = gen.take();
57+
logInfos(rows, cols, shift, mb.getSparsity(), k);
58+
59+
60+
String info = String.format("rows: %5d cols: %5d sp: %.4f shift: %4d k: %2d",
61+
rows, cols, mb.getSparsity(), shift, k);
62+
63+
64+
warmup(this::rollOnce, W);
65+
66+
execute(this::rollOnce, info);
67+
}
68+
69+
private void logInfos(int rows, int cols, int shift, double sparsity, int k) {
70+
String matrixType = sparsity == 1 ? "Dense" : "Sparse";
71+
if (k == 1) {
72+
System.out.println("---------------------------------------------------------------------------------------------------------");
73+
System.out.printf("%s Experiment for rows %d columns %d and shift %d \n", matrixType, rows, cols, shift);
74+
System.out.println("---------------------------------------------------------------------------------------------------------");
75+
}
76+
}
77+
78+
private void rollOnce() {
79+
MatrixBlock in = gen.take();
80+
81+
if (out == null)
82+
out = new MatrixBlock(rows, cols, in.isInSparseFormat());
83+
84+
out.reset(rows, cols, in.isInSparseFormat());
85+
86+
in.reorgOperations(reorg, out, 0, 0, 0);
87+
88+
ret.add(null);
89+
}
90+
91+
@Override
92+
protected String makeResString() {
93+
return "";
94+
}
95+
96+
public static void main(String[] args) throws Exception {
97+
int kMulti = InfrastructureAnalyzer.getLocalParallelism();
98+
int reps = 2000;
99+
int warmup = 200;
100+
101+
int minRows = 2017;
102+
int minCols = 1001;
103+
double spSparse = 0.01;
104+
int minShift = -50;
105+
int maxShift = 1022;
106+
int iterations = 10;
107+
108+
Random rand = new Random(42);
109+
110+
for (int i = 0; i < iterations; i++) {
111+
int rows = 10_000_000;
112+
int cols = 10;
113+
int shift = rand.nextInt((maxShift - minShift) + 1) + minShift;
114+
115+
MatrixBlock denseIn = TestUtils.generateTestMatrixBlock(rows, cols, -100, 100, 1.0, 42);
116+
MatrixBlock sparseIn = TestUtils.generateTestMatrixBlock(rows, cols, -100, 100, spSparse, 42);
117+
118+
// Run Dense Case (Single vs Multi-threaded)
119+
new MatrixRollPerf(reps, warmup, new ConstMatrix(denseIn, -1), rows, cols, shift, 1).run();
120+
new MatrixRollPerf(reps, warmup, new ConstMatrix(denseIn, -1), rows, cols, shift, kMulti).run();
121+
122+
// Run Sparse Case (Single vs Multi-threaded)
123+
new MatrixRollPerf(reps, warmup, new ConstMatrix(sparseIn, -1), rows, cols, shift, 1).run();
124+
new MatrixRollPerf(reps, warmup, new ConstMatrix(sparseIn, -1), rows, cols, shift, kMulti).run();
125+
}
126+
}
127+
}

0 commit comments

Comments
 (0)