|
40 | 40 | public class FederatedCentralMomentTest extends AutomatedTestBase { |
41 | 41 |
|
42 | 42 | private final static String TEST_DIR = "functions/federated/"; |
43 | | - private final static String TEST_NAME = "FederatedCentralMomentTest"; |
| 43 | + private final static String TEST_NAME1 = "FederatedCentralMomentTest"; |
| 44 | + private final static String TEST_NAME2 = "FederatedCentralMomentWeightedTest"; |
44 | 45 | private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCentralMomentTest.class.getSimpleName() + "/"; |
45 | 46 |
|
46 | 47 | private final static int blocksize = 1024; |
47 | 48 | @Parameterized.Parameter() |
48 | 49 | public int rows; |
49 | 50 |
|
50 | 51 | @Parameterized.Parameter(1) |
| 52 | + public int cols; |
| 53 | + |
| 54 | + @Parameterized.Parameter(2) |
51 | 55 | public int k; |
52 | 56 |
|
53 | 57 | @Parameterized.Parameters |
54 | 58 | public static Collection<Object[]> data() { |
55 | | - return Arrays.asList(new Object[][] {{1000, 2}, {1000, 3}, {1000, 4}}); |
| 59 | + return Arrays.asList(new Object[][] {{20, 1, 2}}); |
56 | 60 | } |
57 | 61 |
|
58 | 62 | @Override |
59 | 63 | public void setUp() { |
60 | 64 | TestUtils.clearAssertionInformation(); |
61 | | - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"})); |
| 65 | + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"})); |
| 66 | + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"})); |
62 | 67 | } |
63 | 68 |
|
64 | 69 | @Test |
65 | | - @Ignore // infinite runtime online but works locally. |
66 | 70 | public void federatedCentralMomentCP() { |
67 | | - federatedCentralMoment(Types.ExecMode.SINGLE_NODE); |
| 71 | + federatedCentralMoment(Types.ExecMode.SINGLE_NODE, false); |
| 72 | + } |
| 73 | + |
| 74 | + @Test |
| 75 | + public void federatedCentralMomentWeightedCP() { |
| 76 | + federatedCentralMoment(Types.ExecMode.SINGLE_NODE, true); |
68 | 77 | } |
69 | 78 |
|
70 | 79 | @Test |
71 | | - @Ignore |
72 | 80 | public void federatedCentralMomentSP() { |
73 | | - federatedCentralMoment(Types.ExecMode.SPARK); |
| 81 | + federatedCentralMoment(Types.ExecMode.SPARK, false); |
| 82 | + } |
| 83 | + |
| 84 | + // The test fails due to an error while executing rmvar instruction after cm calculation |
| 85 | + // The CacheStatus of the weights variable is READ hence it can't be modified |
| 86 | + // In this test the input matrix is federated and weights are read from file |
| 87 | + @Ignore |
| 88 | + @Test |
| 89 | + public void federatedCentralMomentWeightedSP() { |
| 90 | + federatedCentralMoment(Types.ExecMode.SPARK, true); |
74 | 91 | } |
75 | 92 |
|
76 | | - public void federatedCentralMoment(Types.ExecMode execMode) { |
| 93 | + public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) { |
77 | 94 | boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; |
78 | 95 | Types.ExecMode platformOld = rtplatform; |
79 | 96 |
|
| 97 | + String TEST_NAME = isWeighted ? TEST_NAME2 : TEST_NAME1; |
80 | 98 | getAndLoadTestConfiguration(TEST_NAME); |
81 | 99 | String HOME = SCRIPT_DIR + TEST_DIR; |
82 | 100 |
|
83 | 101 | int r = rows / 4; |
| 102 | + int c = cols; |
84 | 103 |
|
85 | | - double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3); |
86 | | - double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7); |
87 | | - double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8); |
88 | | - double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9); |
| 104 | + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); |
| 105 | + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); |
| 106 | + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); |
| 107 | + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); |
89 | 108 |
|
90 | 109 | MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, blocksize, r); |
91 | 110 | writeInputMatrixWithMTD("X1", X1, false, mc); |
@@ -114,24 +133,47 @@ public void federatedCentralMoment(Types.ExecMode execMode) { |
114 | 133 | if(rtplatform == Types.ExecMode.SPARK) { |
115 | 134 | DMLScript.USE_LOCAL_SPARK_CONFIG = true; |
116 | 135 | } |
117 | | - // Run reference dml script with normal matrix for Row/Col |
118 | | - fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; |
119 | | - programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), |
120 | | - expected("S"), String.valueOf(k)}; |
121 | | - runTest(null); |
122 | | - |
123 | 136 | TestConfiguration config = availableTestConfigurations.get(TEST_NAME); |
124 | 137 | loadTestConfiguration(config); |
125 | | - |
126 | | - fullDMLScriptName = HOME + TEST_NAME + ".dml"; |
127 | | - programArgs = new String[] {"-stats", "100", "-nvargs", |
128 | | - "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), |
129 | | - "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), |
130 | | - "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), |
131 | | - "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1, |
132 | | - "out_S=" + output("S"), "k=" + k}; |
133 | | - runTest(null); |
134 | | - |
| 138 | + if (isWeighted) { |
| 139 | + double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3); |
| 140 | + |
| 141 | + writeInputMatrixWithMTD("W1", W1, false, mc); |
| 142 | + |
| 143 | + // Run reference dml script with normal matrix |
| 144 | + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; |
| 145 | + programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), |
| 146 | + input("W1"), expected("S"), "" + k}; |
| 147 | + runTest(null); |
| 148 | + |
| 149 | + fullDMLScriptName = HOME + TEST_NAME + ".dml"; |
| 150 | + programArgs = new String[] {"-stats", "100", "-nvargs", |
| 151 | + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), |
| 152 | + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), |
| 153 | + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), |
| 154 | + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), |
| 155 | + "in_W1=" + input("W1"), |
| 156 | + "rows=" + rows, "cols=" + cols, "k=" + k, |
| 157 | + "out_S=" + output("S")}; |
| 158 | + runTest(null); |
| 159 | + } |
| 160 | + else { |
| 161 | + // Run reference dml script with normal matrix for Row/Col |
| 162 | + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; |
| 163 | + programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), |
| 164 | + expected("S"), String.valueOf(k)}; |
| 165 | + runTest(null); |
| 166 | + |
| 167 | + |
| 168 | + fullDMLScriptName = HOME + TEST_NAME + ".dml"; |
| 169 | + programArgs = new String[]{"-stats", "100", "-nvargs", |
| 170 | + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), |
| 171 | + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), |
| 172 | + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), |
| 173 | + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1, |
| 174 | + "out_S=" + output("S"), "k=" + k}; |
| 175 | + runTest(null); |
| 176 | + } |
135 | 177 | // compare all sums via files |
136 | 178 | compareResults(0.01); |
137 | 179 |
|
|
0 commit comments