Skip to content

Commit cd16f7a

Browse files
committed
[SYSTEMDS-3837] Fix trace error handling (only squared matrices)
The trace of a matrix is only defined for squared matrices and our kernels also assume that internally. However, there was on systematic error handling leading to some invalid invocations failing with index-out-of-bounds while others succeeded.
1 parent 33cf89e commit cd16f7a

File tree

4 files changed

+73
-10
lines changed

4 files changed

+73
-10
lines changed

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,9 +942,15 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
942942
output.setBlocksize (id.getBlocksize());
943943
output.setValueType(id.getValueType());
944944
break;
945+
case TRACE:
946+
if(getFirstExpr().getOutput().dimsKnown()
947+
&& getFirstExpr().getOutput().getDim1() != getFirstExpr().getOutput().getDim2())
948+
{
949+
raiseValidateError("Trace is only defined on squared matrices but found ["
950+
+getFirstExpr().getOutput().getDim1()+"x"+getFirstExpr().getOutput().getDim2()+"].", conditional);
951+
}
945952
case SUM:
946953
case PROD:
947-
case TRACE:
948954
case SD:
949955
case VAR:
950956
// sum(X);

src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,20 @@ public class TraceTest extends AutomatedTestBase {
4444

4545
private final static String TEST_DIR = "functions/aggregate/";
4646
private static final String TEST_CLASS_DIR = TEST_DIR + TraceTest.class.getSimpleName() + "/";
47-
private final static String TEST_GENERAL = "General";
48-
private final static String TEST_SCALAR = "Scalar";
47+
private final static String TEST_GENERAL = "TraceTest";
48+
private final static String TEST_SCALAR = "TraceScalarTest";
49+
private final static String TEST_INVALID1 = "TraceInvalid1";
50+
private final static String TEST_INVALID2 = "TraceInvalid2";
4951

5052
@Override
5153
public void setUp() {
5254
// positive tests
53-
addTestConfiguration(TEST_GENERAL, new TestConfiguration(TEST_CLASS_DIR, "TraceTest", new String[] {"b"}));
55+
addTestConfiguration(TEST_GENERAL, new TestConfiguration(TEST_CLASS_DIR, TEST_GENERAL, new String[] {"b"}));
5456

5557
// negative tests
56-
addTestConfiguration(TEST_SCALAR, new TestConfiguration(TEST_CLASS_DIR, "TraceScalarTest", new String[] {"b"}));
58+
addTestConfiguration(TEST_SCALAR, new TestConfiguration(TEST_CLASS_DIR, TEST_SCALAR, new String[] {"b"}));
59+
addTestConfiguration(TEST_INVALID1, new TestConfiguration(TEST_CLASS_DIR, TEST_INVALID1, new String[] {"b"}));
60+
addTestConfiguration(TEST_INVALID2, new TestConfiguration(TEST_CLASS_DIR, TEST_INVALID2, new String[] {"b"}));
5761
}
5862

5963
@Test
@@ -85,16 +89,25 @@ public void testGeneral() {
8589

8690
@Test
8791
public void testScalar() {
88-
int scalar = 12;
89-
9092
TestConfiguration config = getTestConfiguration(TEST_SCALAR);
91-
config.addVariable("scalar", scalar);
93+
config.addVariable("scalar", 12);
9294

9395
createHelperMatrix();
94-
9596
loadTestConfiguration(config);
96-
97+
runTest(true, LanguageException.class);
98+
}
99+
100+
@Test
101+
public void testInvalid1() {
102+
TestConfiguration config = getTestConfiguration(TEST_INVALID1);
103+
loadTestConfiguration(config);
97104
runTest(true, LanguageException.class);
98105
}
99106

107+
@Test
108+
public void testInvalid2() {
109+
TestConfiguration config = getTestConfiguration(TEST_INVALID2);
110+
loadTestConfiguration(config);
111+
runTest(true, LanguageException.class);
112+
}
100113
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
print(trace(rand(rows=100,cols=10)))
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
print(trace(rand(rows=1000,cols=10)))

0 commit comments

Comments
 (0)