Skip to content

Commit 33cf89e

Browse files
committed
[SYSTEMDS-3839] Fix rewrite utils robustness (correct value types)
This patch fixes special cases of creating new unary and binary operators during rewrites, where the value types where not correctly set.
1 parent 1609470 commit 33cf89e

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ public boolean isNumeric() {
136136
}
137137
}
138138

139+
public boolean isFP() {
140+
return this==FP64 || this==FP32;
141+
}
142+
139143
/**
140144
* Helper method to detect Unknown ValueTypes.
141145
*

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,13 @@ public static UnaryOp createUnary(Hop input, String type) {
622622
public static UnaryOp createUnary(Hop input, OpOp1 type) {
623623
DataType dt = type.isScalarOutput() ? DataType.SCALAR :
624624
(type==OpOp1.CAST_AS_MATRIX) ? DataType.MATRIX : input.getDataType();
625-
ValueType vt = (type==OpOp1.CAST_AS_MATRIX) ? ValueType.FP64 : input.getValueType();
625+
ValueType vt = input.getValueType();
626+
switch( type ) {
627+
case CAST_AS_MATRIX:
628+
case CAST_AS_DOUBLE: vt = ValueType.FP64; break;
629+
case CAST_AS_INT: vt = ValueType.INT64; break;
630+
case CAST_AS_BOOLEAN: vt = ValueType.BOOLEAN; break;
631+
}
626632
UnaryOp unary = new UnaryOp(input.getName(), dt, vt, type, input);
627633
unary.setBlocksize(input.getBlocksize());
628634
if( type.isScalarOutput() || type == OpOp1.CAST_AS_MATRIX ) {
@@ -650,11 +656,16 @@ public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op) {
650656
public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean outer) {
651657
Hop mainInput = input1.getDataType().isMatrix() ? input1 :
652658
input2.getDataType().isMatrix() ? input2 : input1;
659+
Hop otherInput = mainInput==input1 ? input2 : input1;
653660
BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(),
654661
mainInput.getValueType(), op, input1, input2);
655-
//cleanup value type for relational operations
662+
//cleanup value type for relational operations and others
663+
if( otherInput.getValueType().isFP() && !mainInput.getValueType().isFP() )
664+
bop.setValueType(otherInput.getValueType());
656665
if( bop.isPPredOperation() && bop.getDataType().isScalar() )
657666
bop.setValueType(ValueType.BOOLEAN);
667+
if( bop.getDataType().isMatrix() )
668+
bop.setValueType(ValueType.FP64);
658669
bop.setOuterVectorOperation(outer);
659670
bop.setBlocksize(mainInput.getBlocksize());
660671
copyLineNumbers(mainInput, bop);
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.misc;
21+
22+
import static org.junit.Assert.assertEquals;
23+
24+
import org.apache.sysds.common.Types.DataType;
25+
import org.apache.sysds.common.Types.OpOp1;
26+
import org.apache.sysds.common.Types.OpOp2;
27+
import org.apache.sysds.common.Types.OpOpData;
28+
import org.apache.sysds.common.Types.ValueType;
29+
import org.apache.sysds.hops.DataOp;
30+
import org.apache.sysds.hops.Hop;
31+
import org.apache.sysds.hops.LiteralOp;
32+
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
33+
import org.junit.Test;
34+
35+
public class RewriteUtilsTest
36+
{
37+
@Test
38+
public void testUnaryValueTypes() {
39+
Hop input = new LiteralOp("true");
40+
41+
assertEquals(ValueType.FP64,
42+
HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_DOUBLE).getValueType());
43+
assertEquals(ValueType.INT64,
44+
HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_INT).getValueType());
45+
assertEquals(ValueType.BOOLEAN,
46+
HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_BOOLEAN).getValueType());
47+
}
48+
49+
@Test
50+
public void testBinaryValueTypes1() {
51+
Hop input1 = new LiteralOp(7d);
52+
Hop input2 = new DataOp("tmp", DataType.MATRIX, ValueType.INT64,
53+
OpOpData.TRANSIENTREAD, null, 3, 7, 21, 1000);
54+
assertEquals(ValueType.FP64,
55+
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT).getValueType());
56+
}
57+
58+
@Test
59+
public void testBinaryValueTypes2() {
60+
Hop input1 = new LiteralOp(7);
61+
Hop input2 = new DataOp("tmp", DataType.MATRIX, ValueType.INT64,
62+
OpOpData.TRANSIENTREAD, null, 3, 7, 21, 1000);
63+
assertEquals(ValueType.FP64,
64+
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT).getValueType());
65+
}
66+
67+
@Test
68+
public void testBinaryValueTypes3() {
69+
Hop input1 = new LiteralOp(7);
70+
Hop input2 = new LiteralOp(3);
71+
assertEquals(ValueType.INT64,
72+
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT).getValueType());
73+
}
74+
}

0 commit comments

Comments
 (0)