|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +package org.apache.sysds.test.functions.rewrite; |
| 21 | + |
| 22 | +import org.apache.sysds.hops.OptimizerUtils; |
| 23 | +import org.apache.sysds.runtime.matrix.data.MatrixValue; |
| 24 | +import org.apache.sysds.test.AutomatedTestBase; |
| 25 | +import org.apache.sysds.test.TestConfiguration; |
| 26 | +import org.apache.sysds.test.TestUtils; |
| 27 | +import org.junit.Assert; |
| 28 | +import org.junit.Test; |
| 29 | + |
| 30 | +public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase { |
| 31 | + private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep"; |
| 32 | + |
| 33 | + private static final String TEST_DIR = "functions/rewrite/"; |
| 34 | + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/"; |
| 35 | + |
| 36 | + @Override |
| 37 | + public void setUp() { |
| 38 | + TestUtils.clearAssertionInformation(); |
| 39 | + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); |
| 40 | + } |
| 41 | + |
| 42 | + @Test |
| 43 | + public void testRewriteReverseSeqStep() { |
| 44 | + testRewriteReverseSeq(TEST_NAME1, true); |
| 45 | + } |
| 46 | + |
| 47 | + @Test |
| 48 | + public void testNoRewriteReverseSeqStep() { |
| 49 | + testRewriteReverseSeq(TEST_NAME1, false); |
| 50 | + } |
| 51 | + |
| 52 | + private void testRewriteReverseSeq(String testname, boolean rewrites) { |
| 53 | + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; |
| 54 | + int rows = 10; |
| 55 | + |
| 56 | + try { |
| 57 | + TestConfiguration config = getTestConfiguration(testname); |
| 58 | + loadTestConfiguration(config); |
| 59 | + |
| 60 | + String HOME = SCRIPT_DIR + TEST_DIR; |
| 61 | + fullDMLScriptName = HOME + testname + ".dml"; |
| 62 | + programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")}; |
| 63 | + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; |
| 64 | + |
| 65 | + runTest(true, false, null, -1); |
| 66 | + |
| 67 | + // Calculate expected sums for each sequence |
| 68 | + double sum1 = calculateSum(0, rows-1, 1); // A1 = rev(seq(0, rows-1, 1)) |
| 69 | + double sum2 = calculateSum(0, rows, 2); // A2 = rev(seq(0, rows, 2)) |
| 70 | + double sum3 = calculateSum(2, rows, 2); // A3 = rev(seq(2, rows, 2)) |
| 71 | + double sum4 = calculateSum(0, 100, 5); // A4 = rev(seq(0, 100, 5)) |
| 72 | + double sum5 = calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5)) |
| 73 | + |
| 74 | + double expected = sum1 + sum2 + sum3 + sum4 + sum5; |
| 75 | + |
| 76 | + double ret = readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1, 1)).doubleValue(); |
| 77 | + |
| 78 | + Assert.assertEquals("Incorrect sum computed", expected, ret, 1e-10); |
| 79 | + |
| 80 | + if (rewrites) { |
| 81 | + // With bidirectional rewrite, REV operations should be removed |
| 82 | + Assert.assertFalse("Rewrite should have removed REV operation!", |
| 83 | + heavyHittersContainsString("rev")); |
| 84 | + } |
| 85 | + } |
| 86 | + finally { |
| 87 | + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + // Helper method to calculate sum of a sequence |
| 92 | + private double calculateSum(double from, double to, double incr) { |
| 93 | + double sum = 0; |
| 94 | + int n = 0; |
| 95 | + |
| 96 | + if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) { |
| 97 | + // Calculate number of elements in the sequence |
| 98 | + n = (int)Math.floor(Math.abs((to - from) / incr)) + 1; |
| 99 | + |
| 100 | + // Calculate the last element in the sequence |
| 101 | + double last = from + (n - 1) * incr; |
| 102 | + |
| 103 | + // Use arithmetic sequence sum formula: n * (first + last) / 2 |
| 104 | + sum = n * (from + last) / 2; |
| 105 | + } |
| 106 | + |
| 107 | + return sum; |
| 108 | + } |
| 109 | +} |
0 commit comments