Skip to content

Commit f27b515

Browse files
author
Biranavan Parameswaran
committed
[SYSTEMDS-3730] Multithreaded roll operation and improved tests
This patch introduces multithreading support for the roll operation to improve performance. RollTest.java has been updated to cover both single-threaded and multithreaded execution modes. Furthermore, this update adds comprehensive consistency checks to ensure mathematical correctness. New tests were created to validate both dense and sparse matrix inputs. Additionally, cross-verification tests were added to confirm that sparse and dense rolling produce identical results in both single-threaded and multithreaded execution.
1 parent d9f6c6d commit f27b515

File tree

8 files changed

+640
-9
lines changed

8 files changed

+640
-9
lines changed

src/main/java/org/apache/sysds/hops/ReorgOp.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ _op, getDataType(), getValueType(), et,
173173
for (int i = 0; i < 2; i++)
174174
linputs[i] = getInput().get(i).constructLops();
175175

176-
Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1);
176+
Transform transform1 = new Transform(
177+
linputs, _op, getDataType(), getValueType(), et,
178+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
177179

178180
setOutputDimensions(transform1);
179181
setLineNumbers(transform1);

src/main/java/org/apache/sysds/lops/Transform.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ private String getInstructions(String input1, int numInputs, String output) {
180180
sb.append( this.prepOutputOperand(output));
181181

182182
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC)
183-
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
183+
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT || _operation == ReOrgOp.ROLL) ) {
184184
sb.append( OPERAND_DELIMITOR );
185185
sb.append( _numThreads );
186186
if ( getExecType()==ExecType.FED ) {

src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,13 @@ else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) {
118118
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str);
119119
}
120120
else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
121-
InstructionUtils.checkNumFields(str, 3);
121+
InstructionUtils.checkNumFields(str, 3, 4);
122122
in.split(parts[1]);
123123
out.split(parts[3]);
124124
CPOperand shift = new CPOperand(parts[2]);
125-
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str);
125+
int k = (parts.length > 4) ? Integer.parseInt(parts[4]) : 1;
126+
127+
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0), k), in, out, shift, opcode, str);
126128
}
127129
else if ( opcode.equalsIgnoreCase(Opcodes.DIAG.toString()) ) {
128130
parseUnaryInstruction(str, in, out); //max 2 operands

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ public static MatrixBlock reorg( MatrixBlock in, MatrixBlock out, ReorgOperator
134134
return rev(in, out);
135135
case ROLL:
136136
RollIndex rix = (RollIndex) op.fn;
137+
if(op.getNumThreads() > 1)
138+
return roll(in, out, rix.getShift(), op.getNumThreads());
137139
return roll(in, out, rix.getShift());
138140
case DIAG:
139141
return diag(in, out);
@@ -514,6 +516,119 @@ public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
514516
return out;
515517
}
516518

519+
public static MatrixBlock roll(MatrixBlock input, MatrixBlock output, int shift, int numThreads) {
520+
521+
final int numRows = input.rlen;
522+
final int numCols = input.clen;
523+
final boolean isSparse = input.sparse;
524+
525+
// sparse-safe operation
526+
if(input.isEmptyBlock(false))
527+
return output;
528+
529+
// special case: row vector
530+
if(numRows == 1) {
531+
output.copy(input);
532+
return output;
533+
}
534+
535+
if(numThreads <= 1 || input.isEmptyBlock(false) || input.getLength() < PAR_NUMCELL_THRESHOLD) {
536+
return roll(input, output, shift); // fallback to single-threaded
537+
}
538+
539+
final int normalizedShift = getNormalizedShiftForRoll(shift, numRows);
540+
541+
output.reset(numRows, numCols, isSparse);
542+
output.nonZeros = input.nonZeros;
543+
544+
if(isSparse) {
545+
output.allocateSparseRowsBlock(false);
546+
}
547+
else {
548+
output.allocateDenseBlock(false);
549+
}
550+
551+
ExecutorService threadPool = CommonThreadPool.get(numThreads);
552+
try {
553+
final int rowsPerThread = (int) Math.ceil((double) numRows / numThreads);
554+
List<Future<?>> tasks = new ArrayList<>();
555+
556+
for(int threadIndex = 0; threadIndex < numThreads; threadIndex++) {
557+
558+
final int startRow = threadIndex * rowsPerThread;
559+
final int endRow = Math.min((threadIndex + 1) * rowsPerThread, numRows);
560+
561+
tasks.add(threadPool.submit(() -> {
562+
if(isSparse)
563+
rollSparseBlock(input, output, normalizedShift, startRow, endRow);
564+
else
565+
rollDenseBlock(input, output, normalizedShift, startRow, endRow);
566+
}));
567+
}
568+
569+
for(Future<?> task : tasks)
570+
task.get();
571+
572+
}
573+
catch(Exception ex) {
574+
throw new DMLRuntimeException(ex);
575+
}
576+
finally {
577+
threadPool.shutdown();
578+
}
579+
580+
return output;
581+
}
582+
583+
private static int getNormalizedShiftForRoll(int shift, int numRows) {
584+
shift = shift % numRows;
585+
if(shift < 0)
586+
shift += numRows;
587+
588+
return shift;
589+
}
590+
591+
private static void rollDenseBlock(MatrixBlock input, MatrixBlock output, int shift, int startRow, int endRow) {
592+
593+
DenseBlock inputBlock = input.getDenseBlock();
594+
DenseBlock outputBlock = output.getDenseBlock();
595+
final int numRows = input.rlen;
596+
final int numCols = input.clen;
597+
598+
for(int targetRow = startRow; targetRow < endRow; targetRow++) {
599+
int sourceRow = targetRow - shift;
600+
if(sourceRow < 0)
601+
sourceRow += numRows;
602+
603+
System.arraycopy(inputBlock.values(sourceRow), inputBlock.pos(sourceRow), outputBlock.values(targetRow),
604+
outputBlock.pos(targetRow), numCols);
605+
}
606+
}
607+
608+
private static void rollSparseBlock(MatrixBlock input, MatrixBlock output, int shift, int startRow, int endRow) {
609+
610+
SparseBlock inputBlock = input.getSparseBlock();
611+
SparseBlock outputBlock = output.getSparseBlock();
612+
final int numRows = input.rlen;
613+
614+
for(int targetRow = startRow; targetRow < endRow; targetRow++) {
615+
int sourceRow = targetRow - shift;
616+
if(sourceRow < 0)
617+
sourceRow += numRows;
618+
619+
if(!inputBlock.isEmpty(sourceRow)) {
620+
int rowStart = inputBlock.pos(sourceRow);
621+
int rowEnd = rowStart + inputBlock.size(sourceRow);
622+
int[] colIndexes = inputBlock.indexes(sourceRow);
623+
double[] values = inputBlock.values(sourceRow);
624+
625+
for(int k = rowStart; k < rowEnd; k++) {
626+
outputBlock.set(targetRow, colIndexes[k], values[k]);
627+
}
628+
}
629+
}
630+
}
631+
517632
public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
518633
MatrixIndexes inMtxIdx = in.getIndexes();
519634
MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
@@ -2554,15 +2669,15 @@ private static void reverseSparse(MatrixBlock in, MatrixBlock out, int rl, int r
25542669

25552670
private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
25562671
final int m = in.rlen;
2557-
shift %= (m != 0 ? m : 1); // roll matrix with axis=none
2672+
shift = getNormalizedShiftForRoll(shift, m); // roll matrix with axis=none
25582673

25592674
copyDenseMtx(in, out, 0, shift, m - shift, false, true);
25602675
copyDenseMtx(in, out, m - shift, 0, shift, true, true);
25612676
}
25622677

25632678
private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
25642679
final int m = in.rlen;
2565-
shift %= (m != 0 ? m : 1); // roll matrix with axis=0
2680+
shift = getNormalizedShiftForRoll(shift, m); // roll matrix with axis=0
25662681

25672682
copySparseMtx(in, out, 0, shift, m - shift, false, true);
25682683
copySparseMtx(in, out, m-shift, 0, shift, false, true);
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.matrix.libMatrixReorg;
21+
22+
import java.util.Arrays;
23+
import java.util.Collection;
24+
25+
import org.apache.sysds.runtime.functionobjects.IndexFunction;
26+
import org.apache.sysds.runtime.functionobjects.RollIndex;
27+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
28+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
29+
import org.apache.sysds.test.TestUtils;
30+
import org.junit.Test;
31+
import org.junit.runner.RunWith;
32+
import org.junit.runners.Parameterized;
33+
34+
@RunWith(Parameterized.class)
35+
public class DenseMatrixRollOperationCorrectnessTest {
36+
37+
private final double[][] input;
38+
private final double[][] expected;
39+
private final int shift;
40+
41+
public DenseMatrixRollOperationCorrectnessTest(double[][] input, double[][] expected, int shift) {
42+
this.input = input;
43+
this.expected = expected;
44+
this.shift = shift;
45+
}
46+
47+
@Parameterized.Parameters(name = "Shift={2}, Size={0}x{1}")
48+
public static Collection<Object[]> data() {
49+
return Arrays.asList(new Object[][] {
50+
{
51+
new double[][] {{1, 2, 3, 4, 5}},
52+
new double[][] {{1, 2, 3, 4, 5}},
53+
0
54+
},
55+
{
56+
new double[][] {{1, 2, 3, 4, 5}},
57+
new double[][] {{1, 2, 3, 4, 5}},
58+
1
59+
},
60+
{
61+
new double[][] {{1, 2, 3, 4, 5}},
62+
new double[][] {{1, 2, 3, 4, 5}},
63+
-3
64+
},
65+
{
66+
new double[][] {{1, 2, 3, 4, 5}},
67+
new double[][] {{1, 2, 3, 4, 5}},
68+
999
69+
},
70+
{
71+
new double[][] {{1}, {2}, {3}, {4}, {5}},
72+
new double[][] {{4}, {5}, {1}, {2}, {3}},
73+
2
74+
},
75+
{
76+
new double[][] {{1}, {2}, {3}, {4}, {5}},
77+
new double[][] {{2}, {3}, {4}, {5}, {1}},
78+
-1
79+
},
80+
{
81+
new double[][] {{1}, {2}, {3}, {4}, {5}},
82+
new double[][] {{1}, {2}, {3}, {4}, {5}},
83+
5
84+
},
85+
{
86+
new double[][] {{1, 2, 3}, {4, 5, 6}},
87+
new double[][] {{4, 5, 6}, {1, 2, 3}},
88+
1
89+
},
90+
{
91+
new double[][] {{1, 2, 3}, {4, 5, 6}},
92+
new double[][] {{4, 5, 6}, {1, 2, 3}},
93+
7
94+
},
95+
{
96+
new double[][] {{1, 2, 3}, {4, 5, 6}},
97+
new double[][] {{1, 2, 3}, {4, 5, 6}},
98+
2
99+
},
100+
{
101+
new double[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}},
102+
new double[][] {{7, 8, 9}, {1, 2, 3}, {4, 5, 6}},
103+
1
104+
},
105+
{
106+
new double[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}},
107+
new double[][] {{4, 5, 6}, {7, 8, 9}, {1, 2, 3}},
108+
-1
109+
},
110+
{
111+
new double[][] {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}},
112+
new double[][] {{3, 2, 1}, {9, 8, 7}, {6, 5, 4}},
113+
1
114+
},
115+
{
116+
new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
117+
new double[][] {{9, 10, 11, 12}, {1, 2, 3, 4}, {5, 6, 7, 8}},
118+
1
119+
},
120+
{
121+
new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
122+
new double[][] {{5, 6, 7, 8}, {9, 10, 11, 12}, {1, 2, 3, 4}},
123+
-1
124+
},
125+
{
126+
new double[][] {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}, {21, 22, 23, 24, 25}},
127+
new double[][] {{21, 22, 23, 24, 25}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}},
128+
1
129+
},
130+
{
131+
new double[][] {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}, {21, 22, 23, 24, 25}},
132+
new double[][] {{11, 12, 13, 14, 15}, {16, 17, 18, 19, 20}, {21, 22, 23, 24, 25}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
133+
-2
134+
},
135+
{
136+
new double[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}, {13, 14, 15}, {16, 17, 18}, {19, 20, 21},
137+
{22, 23, 24}, {25, 26, 27}, {28, 29, 30}},
138+
new double[][] {{22, 23, 24}, {25, 26, 27}, {28, 29, 30}, {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12},
139+
{13, 14, 15}, {16, 17, 18}, {19, 20, 21}},
140+
3
141+
},
142+
{
143+
new double[][] {{1, 2}, {3, 4}, {5, 6}, {7, 8}},
144+
new double[][] {{5, 6}, {7, 8}, {1, 2}, {3, 4}},
145+
1002
146+
},
147+
{
148+
new double[][] {{1}, {2}, {3}, {4}, {5}},
149+
new double[][] {{3}, {4}, {5}, {1}, {2}},
150+
-12
151+
},
152+
{
153+
new double[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}},
154+
new double[][] {{4, 5, 6}, {7, 8, 9}, {1, 2, 3}},
155+
-10
156+
},
157+
{
158+
new double[][] {{1, 2}, {3, 4}, {5, 6}, {7, 8}},
159+
new double[][] {{1, 2}, {3, 4}, {5, 6}, {7, 8}},
160+
-4
161+
},
162+
{
163+
new double[][] {{1, 2}, {3, 4}, {5, 6}, {7, 8}},
164+
new double[][] {{3, 4}, {5, 6}, {7, 8}, {1, 2}},
165+
-5
166+
}
167+
});
168+
}
169+
170+
@Test
171+
public void testRollOperationProducesExpectedOutput() {
172+
MatrixBlock inBlock = new MatrixBlock(input.length, input[0].length, false);
173+
inBlock.init(input, input.length, input[0].length);
174+
175+
IndexFunction op = new RollIndex(shift);
176+
MatrixBlock outBlock = inBlock.reorgOperations(new ReorgOperator(op), new MatrixBlock(), 0, 0, 5);
177+
178+
MatrixBlock expectedBlock = new MatrixBlock(expected.length, expected[0].length, false);
179+
expectedBlock.init(expected, expected.length, expected[0].length);
180+
181+
TestUtils.compareMatrices(outBlock, expectedBlock, 1e-12, "Dense Roll operation does not match expected output");
182+
}
183+
}

0 commit comments

Comments
 (0)