Skip to content

Commit d0287b1

Browse files
e-straussmboehm7
authored andcommitted
[SYSTEMDS-3772] Nary elementwise multiplication rewrite/runtime
- new runtime operations and rewrite for nary mult similar to add - fixed bug in nary hop, where output value type was computed wrongfully in case of scalars with mixed value types - fixed bug in nary hop, where output is wrongfully set as scalar - clean up & fixed the aggregate ternary rewrite to also work nary mult & fixed the n* lineage issue - increased code coverage for rewrite ternary aggregate: colsum(A^3) Closes #2105.
1 parent e0966e7 commit d0287b1

23 files changed

+632
-113
lines changed

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,10 +739,10 @@ public String toString() {
739739

740740
/** Operations that require a variable number of operands*/
741741
public enum OpOpN {
742-
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, EVAL, LIST;
742+
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
743743

744744
public boolean isCellOp() {
745-
return this == MIN || this == MAX || this == PLUS;
745+
return this == MIN || this == MAX || this == PLUS || this == MULT;
746746
}
747747
}
748748

src/main/java/org/apache/sysds/hops/AggUnaryOp.java

Lines changed: 86 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.sysds.hops;
2121

2222
import org.apache.sysds.api.DMLScript;
23+
import org.apache.sysds.common.Types;
2324
import org.apache.sysds.common.Types.AggOp;
2425
import org.apache.sysds.common.Types.DataType;
2526
import org.apache.sysds.common.Types.Direction;
@@ -30,6 +31,7 @@
3031
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
3132
import org.apache.sysds.lops.Lop;
3233
import org.apache.sysds.common.Types.ExecType;
34+
import org.apache.sysds.lops.Nary;
3335
import org.apache.sysds.lops.PartialAggregate;
3436
import org.apache.sysds.lops.TernaryAggregate;
3537
import org.apache.sysds.lops.UAggOuterChain;
@@ -38,6 +40,8 @@
3840
import org.apache.sysds.runtime.meta.DataCharacteristics;
3941
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
4042

43+
import java.util.List;
44+
4145
// Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
4246

4347
public class AggUnaryOp extends MultiThreadedHop
@@ -475,6 +479,17 @@ else if (binput1.getOp() == OpOp2.MULT ) {
475479
}
476480
}
477481
}
482+
if (input1.getParent().size() == 1
483+
&& input1 instanceof NaryOp) { //sum single consumer
484+
NaryOp nop = (NaryOp) input1;
485+
if(nop.getOp() == Types.OpOpN.MULT){
486+
List<Hop> inputsN = nop.getInput();
487+
if(inputsN.size() == 3){
488+
ret = HopRewriteUtils.isEqualSize(inputsN.get(0), inputsN.get(1)) &&
489+
HopRewriteUtils.isEqualSize(inputsN.get(1), inputsN.get(2));
490+
}
491+
}
492+
}
478493
}
479494
return ret;
480495
}
@@ -554,83 +569,91 @@ private boolean isUnaryAggregateOuterCPRewriteApplicable() {
554569

555570
private Lop constructLopsTernaryAggregateRewrite(ExecType et)
556571
{
557-
BinaryOp input1 = (BinaryOp)getInput().get(0);
558-
Hop input11 = input1.getInput().get(0);
559-
Hop input12 = input1.getInput().get(1);
560-
561572
Lop in1 = null, in2 = null, in3 = null;
562-
boolean handled = false;
563-
564-
if (input1.getOp() == OpOp2.POW) {
565-
assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
566-
in1 = input11.constructLops();
567-
in2 = in1;
568-
in3 = in1;
569-
handled = true;
570-
}
571-
else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW) ) {
572-
BinaryOp b11 = (BinaryOp)input11;
573-
switch( b11.getOp() ) {
574-
case MULT: // A*B*C case
575-
in1 = input11.getInput().get(0).constructLops();
576-
in2 = input11.getInput().get(1).constructLops();
577-
in3 = input12.constructLops();
578-
handled = true;
579-
break;
580-
case POW: // A*A*B case
581-
Hop b112 = b11.getInput().get(1);
582-
if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
583-
&& HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
584-
in1 = b11.getInput().get(0).constructLops();
585-
in2 = in1;
586-
in3 = input12.constructLops();
587-
handled = true;
588-
}
589-
break;
590-
default: break;
591-
}
592-
}
593-
else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW) ) {
594-
BinaryOp b12 = (BinaryOp)input12;
595-
switch (b12.getOp()) {
596-
case MULT: // A*B*C case
573+
Hop input = getInput().get(0);
574+
if(input instanceof BinaryOp) {
575+
BinaryOp input1 = (BinaryOp) input;
576+
Hop input11 = input1.getInput().get(0);
577+
Hop input12 = input1.getInput().get(1);
578+
579+
boolean handled = false;
580+
581+
if (input1.getOp() == OpOp2.POW) {
582+
assert (HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
597583
in1 = input11.constructLops();
598-
in2 = input12.getInput().get(0).constructLops();
599-
in3 = input12.getInput().get(1).constructLops();
584+
in2 = in1;
585+
in3 = in1;
600586
handled = true;
601-
break;
602-
case POW: // A*B*B case
603-
Hop b112 = b12.getInput().get(1);
604-
if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
605-
in1 = b12.getInput().get(0).constructLops();
606-
in2 = in1;
607-
in3 = input11.constructLops();
608-
handled = true;
587+
} else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW)) {
588+
BinaryOp b11 = (BinaryOp) input11;
589+
switch (b11.getOp()) {
590+
case MULT: // A*B*C case
591+
in1 = input11.getInput().get(0).constructLops();
592+
in2 = input11.getInput().get(1).constructLops();
593+
in3 = input12.constructLops();
594+
handled = true;
595+
break;
596+
case POW: // A*A*B case
597+
Hop b112 = b11.getInput().get(1);
598+
if (!(input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT)
599+
&& HopRewriteUtils.isLiteralOfValue(b112, 2)) {
600+
in1 = b11.getInput().get(0).constructLops();
601+
in2 = in1;
602+
in3 = input12.constructLops();
603+
handled = true;
604+
}
605+
break;
606+
default:
607+
break;
608+
}
609+
} else if (HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW)) {
610+
BinaryOp b12 = (BinaryOp) input12;
611+
switch (b12.getOp()) {
612+
case MULT: // A*B*C case
613+
in1 = input11.constructLops();
614+
in2 = input12.getInput().get(0).constructLops();
615+
in3 = input12.getInput().get(1).constructLops();
616+
handled = true;
617+
break;
618+
case POW: // A*B*B case
619+
Hop b112 = b12.getInput().get(1);
620+
if (HopRewriteUtils.isLiteralOfValue(b112, 2)) {
621+
in1 = b12.getInput().get(0).constructLops();
622+
in2 = in1;
623+
in3 = input11.constructLops();
624+
handled = true;
625+
}
626+
break;
627+
default:
628+
break;
609629
}
610-
break;
611-
default: break;
612630
}
613-
}
614631

615-
if (!handled) {
616-
in1 = input11.constructLops();
617-
in2 = input12.constructLops();
618-
in3 = new LiteralOp(1).constructLops();
632+
if (!handled) {
633+
in1 = input11.constructLops();
634+
in2 = input12.constructLops();
635+
in3 = new LiteralOp(1).constructLops();
636+
}
637+
} else {
638+
NaryOp input1 = (NaryOp) input;
639+
in1 = input1.getInput().get(0).constructLops();
640+
in2 = input1.getInput().get(1).constructLops();
641+
in3 = input1.getInput().get(2).constructLops();
619642
}
620643

621-
//create new ternary aggregate operator
622-
int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
644+
//create new ternary aggregate operator
645+
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
623646
// The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM
624647
// Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice.
625-
ExecType et_input = input1.optFindExecType();
648+
ExecType et_input = input.optFindExecType();
626649
// Because ternary aggregate are not supported on GPU
627-
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
650+
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
628651
// If forced ExecType is FED, it means that the federated planner updated the ExecType and
629652
// execution may fail if ExecType is not FED
630653
et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED : et_input;
631-
632-
return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
633-
OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
654+
655+
return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
656+
OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
634657
}
635658

636659
@Override

src/main/java/org/apache/sysds/hops/NaryOp.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
200200
HopRewriteUtils.getSumValidInputNnz(dc, true));
201201
case MIN:
202202
case MAX:
203-
case PLUS: return new MatrixCharacteristics(
203+
case PLUS:
204+
case MULT: return new MatrixCharacteristics(
204205
HopRewriteUtils.getMaxInputDim(this, true),
205206
HopRewriteUtils.getMaxInputDim(this, false), -1, -1);
206207
case LIST:
@@ -230,6 +231,7 @@ public void refreshSizeInformation() {
230231
case MIN:
231232
case MAX:
232233
case PLUS:
234+
case MULT:
233235
setDim1(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, true));
234236
setDim2(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, false));
235237
break;

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,8 +745,15 @@ public static LeftIndexingOp createLeftIndexingOp(Hop lhs, Hop rhs, Hop rl, Hop
745745

746746
public static NaryOp createNary(OpOpN op, Hop... inputs) {
747747
Hop mainInput = inputs[0];
748-
NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(),
749-
mainInput.getValueType(), op, inputs);
748+
// safe for unordered inputs of Scalars and Matrices
749+
// e.g.: S*M*S = M
750+
// safe for Scalar with different value type
751+
// e.g.: Scalar(Int) * Scalar(FP64) = Scalar(FP64)
752+
boolean containsMatrix = Arrays.stream(inputs).anyMatch(Hop::isMatrix);
753+
boolean containsFP64 = Arrays.stream(inputs).anyMatch(h -> h.getValueType() == ValueType.FP64);
754+
DataType dtOut = containsMatrix ? DataType.MATRIX : mainInput.getDataType();
755+
ValueType vtOut = containsFP64? ValueType.FP64 : mainInput.getValueType();
756+
NaryOp nop = new NaryOp(mainInput.getName(), dtOut, vtOut, op, inputs);
750757
nop.setBlocksize(mainInput.getBlocksize());
751758
copyLineNumbers(mainInput, nop);
752759
nop.refreshSizeInformation();

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,8 +2801,8 @@ else if( HopRewriteUtils.isBasic1NSequence(second, first, true)
28012801

28022802
private static Hop foldMultipleMinMaxOperations(Hop hi)
28032803
{
2804-
if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS)
2805-
|| HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS))
2804+
if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS, OpOp2.MULT)
2805+
|| HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS, OpOpN.MULT))
28062806
&& hi.getValueType() != ValueType.STRING //exclude string concat
28072807
&& HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi))
28082808
{
@@ -2839,7 +2839,7 @@ private static Hop foldMultipleMinMaxOperations(Hop hi)
28392839
for( Hop p : parents )
28402840
HopRewriteUtils.replaceChildReference(p, hi, hnew);
28412841
hi = hnew;
2842-
LOG.debug("Applied foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+").");
2842+
LOG.debug("Applied foldMultipleMinMaxPlusMultOperations (line "+hi.getBeginLine()+").");
28432843
}
28442844
else {
28452845
converged = true;

src/main/java/org/apache/sysds/lops/Nary.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ private String getOpcode() {
117117
return "n"+operationType.name().toLowerCase();
118118
case PLUS:
119119
return "n+";
120+
case MULT:
121+
return "n*";
120122
default:
121123
throw new UnsupportedOperationException(
122124
"Nary operation type (" + operationType + ") is not defined.");

src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ public class CPInstructionParser extends InstructionParser {
180180
String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
181181
String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
182182
String2CPInstructionType.put( "n+" , CPType.BuiltinNary);
183+
String2CPInstructionType.put( "n*" , CPType.BuiltinNary);
183184

184185
String2CPInstructionType.put( "exp" , CPType.Unary);
185186
String2CPInstructionType.put( "abs" , CPType.Unary);

src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ public class SPInstructionParser extends InstructionParser
299299
String2SPInstructionType.put( "nmin", SPType.BuiltinNary);
300300
String2SPInstructionType.put( "nmax", SPType.BuiltinNary);
301301
String2SPInstructionType.put( "n+", SPType.BuiltinNary);
302+
String2SPInstructionType.put( "n*", SPType.BuiltinNary);
302303

303304
String2SPInstructionType.put( DataGen.RAND_OPCODE , SPType.Rand);
304305
String2SPInstructionType.put( DataGen.SEQ_OPCODE , SPType.Rand);

src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.sysds.common.Types.OpOpN;
2323
import org.apache.sysds.runtime.DMLRuntimeException;
2424
import org.apache.sysds.runtime.functionobjects.Builtin;
25+
import org.apache.sysds.runtime.functionobjects.Multiply;
2526
import org.apache.sysds.runtime.functionobjects.Plus;
2627
import org.apache.sysds.runtime.functionobjects.ValueFunction;
2728
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -85,6 +86,10 @@ else if( opcode.equals("n+") ) {
8586
return new MatrixBuiltinNaryCPInstruction(
8687
new SimpleOperator(Plus.getPlusFnObject()), opcode, str, outputOperand, inputOperands);
8788
}
89+
else if( opcode.equals("n*") ) {
90+
return new MatrixBuiltinNaryCPInstruction(
91+
new SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, inputOperands);
92+
}
8893
else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) {
8994
return new EvalNaryCPInstruction(null, opcode, str, outputOperand, inputOperands);
9095
}

src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void processInstruction(ExecutionContext ec) {
6262
outBlock = ((FrameBlock)outBlock).append(frames.get(i), cbind);
6363
}
6464
}
65-
else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+"}, getOpcode()) ) {
65+
else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+", "n*"}, getOpcode()) ) {
6666
outBlock = MatrixBlock.naryOperations(_optr, matrices.toArray(new MatrixBlock[0]),
6767
scalars.toArray(new ScalarObject[0]), new MatrixBlock());
6868
}

0 commit comments

Comments
 (0)