Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1)
hi = simplifyReverseSequenceStep(hop, hi, i);
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
Expand Down Expand Up @@ -209,6 +210,59 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hop.setVisited();
}

private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) {
if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
&& hi.getInput(0) instanceof DataGenOp
&& ((DataGenOp) hi.getInput(0)).getOp() == OpOpDG.SEQ
&& hi.getInput(0).getParent().size() == 1) { // only one consumer

DataGenOp seq = (DataGenOp) hi.getInput(0);
Hop from = seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM));
Hop to = seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO));
Hop incr = seq.getInput().get(seq.getParamIndex(Statement.SEQ_INCR));

if (from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp) {
double fromVal = ((LiteralOp) from).getDoubleValue();
double toVal = ((LiteralOp) to).getDoubleValue();
double incrVal = ((LiteralOp) incr).getDoubleValue();

// Skip if increment is zero (invalid sequence)
if (Math.abs(incrVal) < 1e-10)
return hi;

boolean isValidDirection = false;

// Checking direction compatibility
if ((incrVal > 0 && fromVal <= toVal) || (incrVal < 0 && fromVal >= toVal)) {
isValidDirection = true;
}

if (isValidDirection) {
// Calculate the number of elements and the last element
int numValues = (int)Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1;
double lastVal = fromVal + (numValues - 1) * incrVal;

// Create a new sequence based on actual last value
LiteralOp newFrom = new LiteralOp(lastVal);
LiteralOp newTo = new LiteralOp(fromVal);
LiteralOp newIncr = new LiteralOp(-incrVal);

// Replace the parameters
seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), newFrom);
seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), newTo);
seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), newIncr);

// Replace the old sequence with the new one
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
HopRewriteUtils.cleanupUnreferenced(hi, seq);
hi = seq;
LOG.debug("Applied simplifyReverseSequenceStep (line " + hi.getBeginLine() + ").");
}
}
}
return hi;
}

private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
{
//applies to all binary matrix operations, if one input is unnecessarily vectorized
Expand Down Expand Up @@ -1853,6 +1907,37 @@ private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
LOG.debug("Applied removeUnecessaryReorgOperation.");
}
}
// Handle the second case: t(X) %*% v -> t(t(v) %*% X)
else if (hi instanceof BinaryOp && ((BinaryOp) hi).getOp() == OpOp2.MULT) {
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);

if (left instanceof ReorgOp && ((ReorgOp) left).getOp() == ReOrgOp.TRANS) {
try {
Hop X = left.getInput().get(0);

// Create transpose of v
Hop transposeV = HopRewriteUtils.createTranspose(right);

// Create multiplication
Hop newMult = HopRewriteUtils.createMatrixMultiply(transposeV, X);

// Create final transpose
Hop finalTranspose = HopRewriteUtils.createTranspose(newMult);

// Replace the original hop with new construct
HopRewriteUtils.replaceChildReference(parent, hi, finalTranspose, pos);
HopRewriteUtils.cleanupUnreferenced(hi);

LOG.debug("Applied removeUnnecessaryReorgOperation.");

return finalTranspose;
}
catch (Exception e) {
LOG.error("Failed to apply removeUnnecessaryReorgOperation: " + e.getMessage(), e);
}
}
}

return hi;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase {
private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep";

private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
}

@Test
public void testRewriteReverseSeqStep() {
testRewriteReverseSeq(TEST_NAME1, true);
}

@Test
public void testNoRewriteReverseSeqStep() {
testRewriteReverseSeq(TEST_NAME1, false);
}

private void testRewriteReverseSeq(String testname, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
int rows = 10;

try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")};
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

runTest(true, false, null, -1);

// Calculate expected sums for each sequence
double sum1 = calculateSum(0, rows-1, 1); // A1 = rev(seq(0, rows-1, 1))
double sum2 = calculateSum(0, rows, 2); // A2 = rev(seq(0, rows, 2))
double sum3 = calculateSum(2, rows, 2); // A3 = rev(seq(2, rows, 2))
double sum4 = calculateSum(0, 100, 5); // A4 = rev(seq(0, 100, 5))
double sum5 = calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5))

double expected = sum1 + sum2 + sum3 + sum4 + sum5;

double ret = readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1, 1)).doubleValue();

Assert.assertEquals("Incorrect sum computed", expected, ret, 1e-10);

if (rewrites) {
// With bidirectional rewrite, REV operations should be removed
Assert.assertFalse("Rewrite should have removed REV operation!",
heavyHittersContainsString("rev"));
}
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
}

// Helper method to calculate sum of a sequence
private double calculateSum(double from, double to, double incr) {
double sum = 0;
int n = 0;

if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) {
// Calculate number of elements in the sequence
n = (int)Math.floor(Math.abs((to - from) / incr)) + 1;

// Calculate the last element in the sequence
double last = from + (n - 1) * incr;

// Use arithmetic sequence sum formula: n * (first + last) / 2
sum = n * (from + last) / 2;
}

return sum;
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

rows = as.integer($1)

# Original test sequences (positive increments)
A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1)
A2 = rev(seq(0, rows, 2)) # Should become seq(rows, 0, -2)
A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2) where lastVal is the last value in the sequence
A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5)
A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5)

# Sum all sequences
R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5)

# Output
write(R, $2)
Loading