Skip to content

Commit 61afba5

Browse files
committed
[SYSTEMDS-3888] Fix size propagation over unique operations
This patch fixes the incorrect size propagation of unique operations, which led to incorrect results if the output dimensions were used in subsequent operations. Thanks to Chi-Hsin Huang for catching this bug. Furthermore, this patch includes minor code-quality updates (removed unused imports, annotated unused functions).
1 parent af73e38 commit 61afba5

File tree

6 files changed

+88
-21
lines changed

6 files changed

+88
-21
lines changed

src/main/java/org/apache/sysds/hops/AggUnaryOp.java

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,18 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) {
323323
DataCharacteristics ret = null;
324324
Hop input = getInput().get(0);
325325
DataCharacteristics dc = memo.getAllInputStats(input);
326-
if( _direction == Direction.Col && dc.colsKnown() )
327-
ret = new MatrixCharacteristics(1, dc.getCols(), -1, -1);
328-
else if( _direction == Direction.Row && dc.rowsKnown() )
329-
ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1);
326+
if( _op == AggOp.UNIQUE ) {
327+
if( _direction == Direction.RowCol && dc.rowsKnown() )
328+
ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1);
329+
else
330+
ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, -1);
331+
}
332+
else {
333+
if( _direction == Direction.Col && dc.colsKnown() )
334+
ret = new MatrixCharacteristics(1, dc.getCols(), -1, -1);
335+
else if( _direction == Direction.Row && dc.rowsKnown() )
336+
ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1);
337+
}
330338
return ret;
331339
}
332340

@@ -648,9 +656,23 @@ private Lop constructLopsTernaryAggregateRewrite(ExecType et)
648656
@Override
649657
public void refreshSizeInformation()
650658
{
651-
if (getDataType() != DataType.SCALAR)
652-
{
653-
Hop input = getInput().get(0);
659+
Hop input = getInput().get(0);
660+
if( _op == AggOp.UNIQUE ) {
661+
if ( _direction == Direction.Col ) {
662+
setDim1(-1); //unknown num unique
663+
setDim2(input.getDim2());
664+
}
665+
else if ( _direction == Direction.Row ) {
666+
setDim1(input.getDim1());
667+
setDim2(-1); //unknown num unique
668+
}
669+
else {
670+
setDim1(-1);
671+
setDim2(1);
672+
}
673+
}
674+
//general case: all other unary aggregations
675+
else if (getDataType() != DataType.SCALAR) {
654676
if ( _direction == Direction.Col ) //colwise computations
655677
{
656678
setDim1(1);

src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ public EstimatorLayeredGraph(int rounds) {
5757

5858
@Override
5959
public DataCharacteristics estim(MMNode root) {
60-
List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
61-
List<OpCode> ops = getOps(root, new ArrayList<>());
62-
List<LayeredGraph> LGs = new ArrayList<>();
60+
//List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
61+
//List<OpCode> ops = getOps(root, new ArrayList<>());
62+
//List<LayeredGraph> LGs = new ArrayList<>();
6363
LayeredGraph ret = traverse(root);
6464
long nnz = ret.estimateNnz();
6565
return root.setDataCharacteristics(new MatrixCharacteristics(
@@ -125,6 +125,7 @@ private static LayeredGraph estimInternal(LayeredGraph lg1, LayeredGraph lg2, Op
125125
}
126126
}
127127

128+
@SuppressWarnings("unused")
128129
private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
129130
//NOTE: this extraction is only correct and efficient for chains, no DAGs
130131
if( node.isLeaf() )
@@ -136,6 +137,7 @@ private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
136137
return leafs;
137138
}
138139

140+
@SuppressWarnings("unused")
139141
private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
140142
//NOTE: this extraction is only correct and efficient for chains, no DAGs
141143
if(node.isLeaf()) {

src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
import org.apache.sysds.common.Types.OpOp1;
2828
import org.apache.sysds.common.Types.OpOp2;
2929
import org.apache.sysds.hops.UnaryOp;
30-
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
31-
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
3230
import org.apache.sysds.hops.BinaryOp;
3331

3432
import org.apache.sysds.common.Types.DataType;

src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -562,24 +562,26 @@ private void validateUnique(DataIdentifier output, boolean conditional) {
562562

563563
private void validateUniqueAggregationDirection(Identifier dataId, DataIdentifier output) {
564564
HashMap<String, Expression> varParams = getVarParams();
565+
String inputDirection = Types.Direction.RowCol.toString();
565566
if (varParams.containsKey("dir")) {
566-
String inputDirectionString = varParams.get("dir").toString().toUpperCase();
567-
567+
inputDirection = varParams.get("dir").toString().toUpperCase();
568568
// unrecognized value for "dir" parameter
569-
if (!inputDirectionString.equals(Types.Direction.Row.toString())
570-
&& !inputDirectionString.equals(Types.Direction.Col.toString())
571-
&& !inputDirectionString.equals(Types.Direction.RowCol.toString())) {
572-
raiseValidateError("Invalid argument: " + inputDirectionString + " is not recognized");
569+
if (!inputDirection.equals(Types.Direction.Row.toString())
570+
&& !inputDirection.equals(Types.Direction.Col.toString())
571+
&& !inputDirection.equals(Types.Direction.RowCol.toString())) {
572+
raiseValidateError("Invalid argument: " + inputDirection + " is not recognized");
573573
}
574574
}
575575

576-
// rc/r/c -> unique return value is the same as the input in the worst case
577576
// default to dir="rc"
578577
output.setDataType(DataType.MATRIX);
579-
output.setDimensions(dataId.getDim1(), dataId.getDim2());
578+
output.setDimensions(
579+
inputDirection.equals(Types.Direction.Row.toString()) ? dataId.getDim1() : -1,
580+
inputDirection.equals(Types.Direction.Col.toString()) ? dataId.getDim2() :
581+
inputDirection.equals(Types.Direction.RowCol.toString()) ? 1 : -1);
580582
output.setBlocksize(dataId.getBlocksize());
581583
output.setValueType(ValueType.FP64);
582-
output.setNnz(dataId.getNnz());
584+
output.setNnz(-1);
583585
}
584586

585587
private void checkStringParam(boolean optional, String fname, String pname, boolean conditional) {

src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.sysds.test.AutomatedTestBase;
2828
import org.apache.sysds.test.TestConfiguration;
2929
import org.apache.sysds.test.TestUtils;
30+
import org.apache.sysds.utils.Statistics;
3031
import org.junit.Assert;
3132

3233
import java.util.HashMap;
@@ -38,6 +39,7 @@ public class SizePropagationTest extends AutomatedTestBase
3839
private static final String TEST_NAME3 = "SizePropagationLoopIx2";
3940
private static final String TEST_NAME4 = "SizePropagationLoopIx3";
4041
private static final String TEST_NAME5 = "SizePropagationLoopIx4";
42+
private static final String TEST_NAME6 = "SizePropagationUnique";
4143

4244
private static final String TEST_DIR = "functions/misc/";
4345
private static final String TEST_CLASS_DIR = TEST_DIR + SizePropagationTest.class.getSimpleName() + "/";
@@ -52,6 +54,7 @@ public void setUp() {
5254
addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
5355
addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
5456
addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
57+
addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
5558
}
5659

5760
@Test
@@ -104,6 +107,16 @@ public void testSizePropagationLoopIx4Rewrites() {
104107
testSizePropagation( TEST_NAME5, true, N );
105108
}
106109

110+
@Test
111+
public void testSizePropagationUnique1() {
112+
testSizePropagation( TEST_NAME6, false, 10 );
113+
}
114+
115+
@Test
116+
public void testSizePropagationUnique2() {
117+
testSizePropagation( TEST_NAME6, false, 10 );
118+
}
119+
107120
private void testSizePropagation( String testname, boolean rewrites, int expect ) {
108121
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
109122
ExecMode oldPlatform = rtplatform;
@@ -122,6 +135,8 @@ private void testSizePropagation( String testname, boolean rewrites, int expect
122135
runTest(true, false, null, -1);
123136
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
124137
Assert.assertEquals(Double.valueOf(expect), dmlfile.get(new CellIndex(1,1)));
138+
if( testname.equals(TEST_NAME6) )
139+
Assert.assertEquals(0, Statistics.getNoOfCompiledSPInst());
125140
}
126141
finally {
127142
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
A = matrix("1 2 3 4 5 6 7", rows=7,cols=1)
23+
B = matrix("4 5 6 7 8 9 10", rows=7,cols=1)
24+
C = rbind(A,B)
25+
D = unique(C)
26+
n = nrow(D);
27+
R = as.matrix(n);
28+
write(R, $2);

0 commit comments

Comments
 (0)