Skip to content

Commit c745550

Browse files
mike0609kingLou Frizzi Maria WagnerLaurits	Sartorius
authored andcommitted
[SYSTEMDS-3778] New determinant function, kernels, rewrites
DIA WiSe 24/25 project Closes #2196. Co-authored-by: Lou Frizzi Maria Wagner <[email protected]> Co-authored-by: Laurits Sartorius <[email protected]>
1 parent 0b5fae9 commit c745550

File tree

22 files changed

+909
-4
lines changed

22 files changed

+909
-4
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ public enum Builtins {
113113
DECISIONTREEPREDICT("decisionTreePredict", true),
114114
DECOMPRESS("decompress", false),
115115
DEEPWALK("deepWalk", true),
116+
DET("det", false),
116117
DETECTSCHEMA("detectSchema", false),
117118
DENIALCONSTRAINTS("denialConstraints", true),
118119
DIFFERENCESTATISTICS("differenceStatistics", true),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ public enum Opcodes {
165165
STOP("stop", CPType.Unary),
166166
INVERSE("inverse", CPType.Unary),
167167
CHOLESKY("cholesky", CPType.Unary),
168+
DET("det", CPType.Unary),
168169
SPROP("sprop", CPType.Unary),
169170
SIGMOID("sigmoid", CPType.Unary),
170171
TYPEOF("typeOf", CPType.Unary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ public enum OpOp1 {
539539
CAST_AS_FRAME, CAST_AS_LIST, CAST_AS_MATRIX, CAST_AS_SCALAR,
540540
CAST_AS_BOOLEAN, CAST_AS_DOUBLE, CAST_AS_INT,
541541
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
542-
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
542+
CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
543543
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
544544
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
545545
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
@@ -558,6 +558,7 @@ public enum OpOp1 {
558558

559559
public boolean isScalarOutput() {
560560
return this == CAST_AS_SCALAR
561+
|| this == DET
561562
|| this == NROW || this == NCOL
562563
|| this == LENGTH || this == EXISTS
563564
|| this == IQM || this == LINEAGE
@@ -579,6 +580,7 @@ public String toString() {
579580
case CUMPROD: return Opcodes.UCUMM.toString();
580581
case CUMSUM: return Opcodes.UCUMKP.toString();
581582
case CUMSUMPROD: return Opcodes.UCUMKPM.toString();
583+
case DET: return Opcodes.DET.toString();
582584
case DETECTSCHEMA: return Opcodes.DETECTSCHEMA.toString();
583585
case MULT2: return Opcodes.MULT2.toString();
584586
case NOT: return Opcodes.NOT.toString();

src/main/java/org/apache/sysds/hops/UnaryOp.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
512512

513513
//ensure cp exec type for single-node operations
514514
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF
515-
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
515+
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.DET
516+
||_op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
516517
|| getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() )
517518
{
518519
_etype = ExecType.CP;

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,15 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
159159
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
160160
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
161161
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
162+
hi = simplifyTransposeInDetOperation(hop, hi, i); //e.g., det(t(X)) -> det(X)
162163
hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
163164
hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X)
164165
hi = removeUnnecessaryAggregates(hi); //e.g., sum(rowSums(X)) -> sum(X)
165166
hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
166167
hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
167168
hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
169+
hi = pushdownDetMultOperation(hop, hi, i); //e.g., det(X%*%Y) -> det(X)*det(Y)
170+
hi = pushdownDetScalarMatrixMultOperation(hop, hi, i); //e.g., det(lambda*X) -> lambda^nrow(X)*det(X)
168171
hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lambda*X) -> lambda*sum(X)
169172
hi = pullupAbs(hop, hi, i); //e.g., abs(X)*abs(Y) --> abs(X*Y)
170173
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
@@ -922,6 +925,29 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int
922925
return hi;
923926
}
924927

928+
/**
929+
* det(t(X)) -> det(X)
930+
*
931+
* @param parent parent high-level operator
932+
* @param hi high-level operator
933+
* @param pos position
934+
* @return high-level operator
935+
*/
936+
private static Hop simplifyTransposeInDetOperation(Hop parent, Hop hi, int pos)
937+
{
938+
if(HopRewriteUtils.isUnary(hi, OpOp1.DET)
939+
&& HopRewriteUtils.isReorg(hi.getInput(0), ReOrgOp.TRANS))
940+
{
941+
Hop operand = hi.getInput(0).getInput(0);
942+
Hop uop = HopRewriteUtils.createUnary(operand, OpOp1.DET);
943+
HopRewriteUtils.replaceChildReference(parent, hi, uop, pos);
944+
945+
LOG.debug("Applied simplifyTransposeInDetOperation.");
946+
return uop;
947+
}
948+
return hi;
949+
}
950+
925951
/**
926952
* t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v)
927953
* t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
@@ -1163,6 +1189,65 @@ private static Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int
11631189
return hi;
11641190
}
11651191

1192+
/**
1193+
* det(X%*%Y) -> det(X)*det(Y)
1194+
*
1195+
* @param parent parent high-level operator
1196+
* @param hi high-level operator
1197+
* @param pos position
1198+
* @return high-level operator
1199+
*/
1200+
private static Hop pushdownDetMultOperation(Hop parent, Hop hi, int pos) {
1201+
if( HopRewriteUtils.isUnary(hi, OpOp1.DET)
1202+
&& HopRewriteUtils.isMatrixMultiply(hi.getInput(0))
1203+
&& hi.getInput(0).getInput(0).isMatrix()
1204+
&& hi.getInput(0).getInput(1).isMatrix())
1205+
{
1206+
Hop operand1 = hi.getInput(0).getInput(0);
1207+
Hop operand2 = hi.getInput(0).getInput(1);
1208+
Hop uop1 = HopRewriteUtils.createUnary(operand1, OpOp1.DET);
1209+
Hop uop2 = HopRewriteUtils.createUnary(operand2, OpOp1.DET);
1210+
Hop bop = HopRewriteUtils.createBinary(uop1, uop2, OpOp2.MULT);
1211+
HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
1212+
1213+
LOG.debug("Applied pushdownDetMultOperation.");
1214+
return bop;
1215+
}
1216+
return hi;
1217+
}
1218+
1219+
/**
1220+
* det(lambda*X) -> lambda^nrow*det(X)
1221+
*
1222+
* @param parent parent high-level operator
1223+
* @param hi high-level operator
1224+
* @param pos position
1225+
* @return high-level operator
1226+
*/
1227+
private static Hop pushdownDetScalarMatrixMultOperation(Hop parent, Hop hi, int pos) {
1228+
if( HopRewriteUtils.isUnary(hi, OpOp1.DET)
1229+
&& HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT)
1230+
&& ((hi.getInput(0).getInput(0).isMatrix() && hi.getInput(0).getInput(1).isScalar())
1231+
|| (hi.getInput(0).getInput(0).isScalar() && hi.getInput(0).getInput(1).isMatrix())))
1232+
{
1233+
Hop operand1 = hi.getInput(0).getInput(0);
1234+
Hop operand2 = hi.getInput(0).getInput(1);
1235+
1236+
Hop lambda = (operand1.isScalar()) ? operand1 : operand2;
1237+
Hop matrix = (operand1.isMatrix()) ? operand1 : operand2;
1238+
1239+
Hop uopDet = HopRewriteUtils.createUnary(matrix, OpOp1.DET);
1240+
Hop uopNrow = HopRewriteUtils.createUnary(matrix, OpOp1.NROW);
1241+
Hop bopPow = HopRewriteUtils.createBinary(lambda, uopNrow, OpOp2.POW);
1242+
Hop bopMult = HopRewriteUtils.createBinary(bopPow, uopDet, OpOp2.MULT);
1243+
HopRewriteUtils.replaceChildReference(parent, hi, bopMult, pos);
1244+
1245+
LOG.debug("Applied pushdownDetScalarMatrixMultOperation.");
1246+
return bopMult;
1247+
}
1248+
return hi;
1249+
}
1250+
11661251
private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) {
11671252
//pattern: sum(lamda*X) -> lamda*sum(X)
11681253
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,17 @@ else if( getOpCode() == Builtins.RBIND ) {
13021302
output.setBlocksize(id.getBlocksize());
13031303
output.setValueType(id.getValueType());
13041304
break;
1305+
case DET:
1306+
checkNumParameters(1);
1307+
checkMatrixParam(getFirstExpr());
1308+
if ( id.getDim2() == -1 || id.getDim1() != id.getDim2() ) {
1309+
raiseValidateError("det requires a square matrix as first argument.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
1310+
}
1311+
output.setDataType(DataType.SCALAR);
1312+
output.setDimensions(0, 0);
1313+
output.setBlocksize(0);
1314+
output.setValueType(ValueType.FP64);
1315+
break;
13051316
case NROW:
13061317
case NCOL:
13071318
case LENGTH:

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,6 +2755,7 @@ else if ( in.length == 2 )
27552755
case SQRT_MATRIX_JAVA:
27562756
case CHOLESKY:
27572757
case TYPEOF:
2758+
case DET:
27582759
case DETECTSCHEMA:
27592760
case COLNAMES:
27602761
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),

src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ public static long getInstNFLOP(
501501
case "cholesky":
502502
costs = (1.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity();
503503
break;
504+
case "det":
504505
case "detectschema":
505506
case "colnames":
506507
throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");

src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryMatrixCPInstruction.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.sysds.runtime.instructions.cp;
2121

22+
import org.apache.sysds.common.Types.ValueType;
2223
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
2324
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2425
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -57,7 +58,13 @@ public void processInstruction(ExecutionContext ec) {
5758
LineageItem lin = (!inObj.hasValidLineage() || !inObj.getCacheLineage().isLeaf() ||
5859
CacheableData.isBelowCachingThreshold(retBlock)) ? null :
5960
getCacheLineageItem(inObj.getCacheLineage());
60-
ec.setMatrixOutputAndLineage(output, retBlock, lin);
61+
if (getOpcode().equals("det")){
62+
var temp = ScalarObjectFactory.createScalarObject(ValueType.FP64, retBlock.get(0,0));
63+
ec.setVariable(output.getName(), temp);
64+
}
65+
else {
66+
ec.setMatrixOutputAndLineage(output, retBlock, lin);
67+
}
6168
}
6269

6370
public LineageItem getCacheLineageItem(LineageItem input) {

0 commit comments

Comments
 (0)