Skip to content

Commit edfce10

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3729] New roll reorg operations in CP, incl tests
Closes #2103.
1 parent c940502 commit edfce10

File tree

16 files changed

+473
-39
lines changed

16 files changed

+473
-39
lines changed

.github/workflows/javaTests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,4 @@ jobs:
177177
name: Java Code Coverage (Jacoco)
178178
path: target/site/jacoco
179179
retention-days: 3
180+

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ public enum Builtins {
282282
RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
283283
REMOVE("remove", false, ReturnType.MULTI_RETURN),
284284
REV("rev", false),
285+
ROLL("roll", false),
285286
ROUND("round", false),
286287
ROW_COUNT_DISTINCT("rowCountDistinct",false),
287288
ROWINDEXMAX("rowIndexMax", false),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ public boolean isCellOp() {
749749
/** Operations that perform internal reorganization of an allocation */
750750
public enum ReOrgOp {
751751
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown
752-
RESHAPE, REV, SORT, TRANS;
752+
RESHAPE, REV, ROLL, SORT, TRANS;
753753

754754
@Override
755755
public String toString() {

src/main/java/org/apache/sysds/hops/ReorgOp.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ public void checkArity() {
9090
case REV:
9191
HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", _op, sz);
9292
break;
93+
case ROLL:
94+
HopsException.check(sz == 2, this, "should have arity 2 for op %s but has arity %d", _op, sz);
95+
break;
9396
case RESHAPE:
9497
case SORT:
9598
HopsException.check(sz == 5, this, "should have arity 5 for op %s but has arity %d", _op, sz);
@@ -125,6 +128,7 @@ public boolean isGPUEnabled() {
125128
}
126129
case DIAG:
127130
case REV:
131+
case ROLL:
128132
case SORT:
129133
return false;
130134
default:
@@ -175,6 +179,18 @@ else if( getDim1()==1 && getDim2()==1 )
175179
setLops(transform1);
176180
break;
177181
}
182+
case ROLL: {
183+
Lop[] linputs = new Lop[2]; //input, shift
184+
for (int i = 0; i < 2; i++)
185+
linputs[i] = getInput().get(i).constructLops();
186+
187+
Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1);
188+
189+
setOutputDimensions(transform1);
190+
setLineNumbers(transform1);
191+
setLops(transform1);
192+
break;
193+
}
178194
case RESHAPE: {
179195
Lop[] linputs = new Lop[5]; //main, rows, cols, dims, byrow
180196
for (int i = 0; i < 5; i++)
@@ -279,9 +295,10 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
279295
ret = new MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
280296
break;
281297
}
282-
case REV: {
298+
case REV:
299+
case ROLL: {
283300
// dims and nnz are exactly the same as in input
284-
if( dc.dimsKnown() )
301+
if (dc.dimsKnown())
285302
ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
286303
break;
287304
}
@@ -397,6 +414,7 @@ public void refreshSizeInformation()
397414
break;
398415
}
399416
case REV:
417+
case ROLL:
400418
{
401419
// dims and nnz are exactly the same as in input
402420
setDim1(input1.getDim1());

src/main/java/org/apache/sysds/lops/Transform.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ private String getOpcode() {
111111
case REV:
112112
// Transpose a matrix
113113
return "rev";
114-
114+
115+
case ROLL:
116+
return "roll";
117+
115118
case DIAG:
116119
// Transform a vector into a diagonal matrix
117120
return "rdiag";
@@ -138,6 +141,12 @@ public String getInstructions(String input1, String output) {
138141
return getInstructions(input1, 1, output);
139142
}
140143

144+
@Override
145+
public String getInstructions(String input1, String input2, String output) {
146+
//opcodes: roll
147+
return getInstructions(input1, 2, output);
148+
}
149+
141150
@Override
142151
public String getInstructions(String input1, String input2, String input3, String input4, String output) {
143152
//opcodes: rsort

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,17 @@ else if( getOpCode() == Builtins.RBIND ) {
12661266
output.setBlocksize (id.getBlocksize());
12671267
output.setValueType(id.getValueType());
12681268
break;
1269-
1269+
1270+
case ROLL:
1271+
checkNumParameters(2);
1272+
checkMatrixParam(getFirstExpr());
1273+
checkScalarParam(getSecondExpr());
1274+
output.setDataType(DataType.MATRIX);
1275+
output.setDimensions(id.getDim1(), id.getDim2());
1276+
output.setBlocksize(id.getBlocksize());
1277+
output.setValueType(id.getValueType());
1278+
break;
1279+
12701280
case DIAG:
12711281
checkNumParameters(1);
12721282
checkMatrixParam(getFirstExpr());

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,14 @@ else if ( sop.equalsIgnoreCase("!=") )
24812481
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr);
24822482
break;
24832483

2484+
case ROLL:
2485+
ArrayList<Hop> inputs = new ArrayList<>();
2486+
inputs.add(expr);
2487+
inputs.add(expr2);
2488+
currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
2489+
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs);
2490+
break;
2491+
24842492
case CBIND:
24852493
case RBIND:
24862494
OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.functionobjects;
21+
22+
import org.apache.commons.lang3.NotImplementedException;
23+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
24+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
25+
import org.apache.sysds.runtime.meta.DataCharacteristics;
26+
27+
/**
28+
* This index function is NOT used for actual sorting but just as a reference
29+
* in ReorgOperator in order to identify sort operations.
30+
*/
31+
public class RollIndex extends IndexFunction {
32+
private static final long serialVersionUID = -8446389232078905200L;
33+
34+
private final int _shift;
35+
36+
public RollIndex(int shift) {
37+
_shift = shift;
38+
}
39+
40+
public int getShift() {
41+
return _shift;
42+
}
43+
44+
@Override
45+
public boolean computeDimension(int row, int col, CellIndex retDim) {
46+
retDim.set(row, col);
47+
return false;
48+
}
49+
50+
@Override
51+
public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
52+
out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros());
53+
return false;
54+
}
55+
56+
@Override
57+
public void execute(MatrixIndexes in, MatrixIndexes out) {
58+
throw new NotImplementedException();
59+
}
60+
61+
@Override
62+
public void execute(CellIndex in, CellIndex out) {
63+
throw new NotImplementedException();
64+
}
65+
}

src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ public class CPInstructionParser extends InstructionParser {
271271
// Reorg Instruction Opcodes (repositioning of existing values)
272272
String2CPInstructionType.put( "r'" , CPType.Reorg);
273273
String2CPInstructionType.put( "rev" , CPType.Reorg);
274+
String2CPInstructionType.put( "roll" , CPType.Reorg);
274275
String2CPInstructionType.put( "rdiag" , CPType.Reorg);
275276
String2CPInstructionType.put( "rshape" , CPType.Reshape);
276277
String2CPInstructionType.put( "rsort" , CPType.Reorg);

src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2626
import org.apache.sysds.runtime.functionobjects.DiagIndex;
2727
import org.apache.sysds.runtime.functionobjects.RevIndex;
28+
import org.apache.sysds.runtime.functionobjects.RollIndex;
2829
import org.apache.sysds.runtime.functionobjects.SortIndex;
2930
import org.apache.sysds.runtime.functionobjects.SwapIndex;
3031
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -38,51 +39,58 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
3839
private final CPOperand _col;
3940
private final CPOperand _desc;
4041
private final CPOperand _ixret;
42+
private final CPOperand _shift;
4143

4244
/**
4345
* for opcodes r' and rdiag
44-
*
45-
* @param op
46-
* operator
47-
* @param in
48-
* cp input operand
49-
* @param out
50-
* cp output operand
51-
* @param opcode
52-
* the opcode
53-
* @param istr
54-
* ?
46+
*
47+
* @param op operator
48+
* @param in cp input operand
49+
* @param out cp output operand
50+
* @param opcode the opcode
51+
* @param istr ?
5552
*/
5653
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
5754
this(op, in, out, null, null, null, opcode, istr);
5855
}
5956

6057
/**
6158
* for opcode rsort
62-
*
63-
* @param op
64-
* operator
65-
* @param in
66-
* cp input operand
67-
* @param col
68-
* ?
69-
* @param desc
70-
* ?
71-
* @param ixret
72-
* ?
73-
* @param out
74-
* cp output operand
75-
* @param opcode
76-
* the opcode
77-
* @param istr
78-
* ?
59+
*
60+
* @param op operator
61+
* @param in cp input operand
62+
* @param col ?
63+
* @param desc ?
64+
* @param ixret ?
65+
* @param out cp output operand
66+
* @param opcode the opcode
67+
* @param istr ?
7968
*/
8069
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret,
81-
String opcode, String istr) {
70+
String opcode, String istr) {
8271
super(CPType.Reorg, op, in, out, opcode, istr);
8372
_col = col;
8473
_desc = desc;
8574
_ixret = ixret;
75+
_shift = null;
76+
}
77+
78+
/**
79+
* for opcode roll
80+
*
81+
* @param op operator
82+
* @param in cp input operand
83+
* @param shift ?
84+
* @param out cp output operand
85+
* @param opcode the opcode
86+
* @param istr ?
87+
*/
88+
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
89+
super(CPType.Reorg, op, in, out, opcode, istr);
90+
_col = null;
91+
_desc = null;
92+
_ixret = null;
93+
_shift = shift;
8694
}
8795

8896
public static ReorgCPInstruction parseInstruction ( String str ) {
@@ -103,6 +111,13 @@ else if ( opcode.equalsIgnoreCase("rev") ) {
103111
parseUnaryInstruction(str, in, out); //max 2 operands
104112
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
105113
}
114+
else if (opcode.equalsIgnoreCase("roll")) {
115+
InstructionUtils.checkNumFields(str, 3);
116+
in.split(parts[1]);
117+
out.split(parts[3]);
118+
CPOperand shift = new CPOperand(parts[2]);
119+
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str);
120+
}
106121
else if ( opcode.equalsIgnoreCase("rdiag") ) {
107122
parseUnaryInstruction(str, in, out); //max 2 operands
108123
return new ReorgCPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -136,7 +151,12 @@ public void processInstruction(ExecutionContext ec) {
136151
boolean ixret = ec.getScalarInput(_ixret).getBooleanValue();
137152
r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
138153
}
139-
154+
155+
if (r_op.fn instanceof RollIndex) {
156+
int shift = (int) ec.getScalarInput(_shift).getLongValue();
157+
r_op = r_op.setFn(new RollIndex(shift));
158+
}
159+
140160
//execute operation
141161
MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
142162

0 commit comments

Comments
 (0)