Skip to content

Commit 615cd9a

Browse files
committed
[SYSTEMDS-3823] Compression test case for bultin kmeans
Closes #2194
1 parent 41c21bf commit 615cd9a

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,9 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
17301730
* (the invoker is responsible to recompute nnz after all copies are done)
17311731
*/
17321732
public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareDestNZ ) {
1733+
if (src instanceof CompressedMatrixBlock){
1734+
src = ((CompressedMatrixBlock) src).getUncompressed("In-place matrix copy into indexed matrix");
1735+
}
17331736
if(sparse && src.sparse)
17341737
copySparseToSparse(rl, ru, cl, cu, src, awareDestNZ);
17351738
else if(sparse && !src.sparse)

src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import static org.junit.Assert.fail;
2323

24+
import java.io.ByteArrayOutputStream;
2425
import java.io.File;
2526

2627
import org.apache.commons.logging.Log;
@@ -46,6 +47,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
4647
private final static String TEST_NAME5 = "WorkloadAnalysisSliceFinder";
4748
private final static String TEST_NAME6 = "WorkloadAnalysisLmCG";
4849
private final static String TEST_NAME7 = "WorkloadAnalysisL2SVM";
50+
private final static String TEST_NAME8 = "WorkloadAnalysisKmeans";
4951
private final static String TEST_DIR = "functions/compress/workload/";
5052
private final static String TEST_CLASS_DIR = TEST_DIR + WorkloadAnalysisTest.class.getSimpleName() + "/";
5153

@@ -73,6 +75,7 @@ public void setUp() {
7375
addTestConfiguration(TEST_NAME5, new TestConfiguration(dir, TEST_NAME5, new String[] {"B"}));
7476
addTestConfiguration(TEST_NAME6, new TestConfiguration(dir, TEST_NAME6, new String[] {"B"}));
7577
addTestConfiguration(TEST_NAME7, new TestConfiguration(dir, TEST_NAME7, new String[] {"B"}));
78+
addTestConfiguration(TEST_NAME8, new TestConfiguration(dir, TEST_NAME8, new String[] {"B"}));
7679
}
7780

7881
@Test
@@ -143,8 +146,23 @@ public void testL2SVMCP() {
143146
runWorkloadAnalysisTest(TEST_NAME7, ExecMode.SINGLE_NODE, 2, false);
144147
}
145148

149+
@Test
150+
public void testKmeansSuccessfulCP() {
151+
runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, false, 30);
152+
}
153+
154+
@Test
155+
public void testKmeansUnsuccessfulCP() {
156+
runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, false, 10);
157+
}
158+
159+
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates){
160+
runWorkloadAnalysisTest(testname, mode, compressionCount, intermediates, -1);
161+
}
162+
146163
// private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) {
147-
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) {
164+
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates,
165+
int maxIter) {
148166
ExecMode oldPlatform = setExecMode(mode);
149167
boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
150168

@@ -154,19 +172,20 @@ private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compres
154172

155173
String HOME = SCRIPT_DIR + TEST_DIR;
156174
fullDMLScriptName = HOME + testname + ".dml";
157-
programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
175+
programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B"),
176+
String.valueOf(maxIter)};
158177

159178
writeInputMatrixWithMTD("X", X, false);
160179
writeInputMatrixWithMTD("y", y, false);
161180

162-
String ret = runTest(null).toString();
181+
ByteArrayOutputStream out = runTest(null);
182+
String ret = out != null ? out.toString() : "";
163183
LOG.debug(ret);
164184

165185
// check various additional expectations
166186
long actualCompressionCount = (mode == ExecMode.HYBRID || mode == ExecMode.SINGLE_NODE) ? Statistics
167187
.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");
168-
169-
Assert.assertEquals("Assert that the compression counts expeted matches actual: " + compressionCount + " vs "
188+
Assert.assertEquals("Assert that the compression counts expected matches actual: " + compressionCount + " vs "
170189
+ actualCompressionCount, compressionCount, actualCompressionCount);
171190
if(compressionCount > 0)
172191
Assert.assertTrue(mode == ExecMode.SINGLE_NODE || mode == ExecMode.HYBRID ? heavyHittersContainsString(
@@ -176,6 +195,7 @@ private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compres
176195

177196
}
178197
catch(Exception e) {
198+
e.printStackTrace();
179199
resetExecMode(oldPlatform);
180200
fail("Failed workload test");
181201
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = read($1);
23+
24+
25+
print("")
26+
print("kmeans")
27+
28+
[data, Centering, ScaleFactor] = scale(X, TRUE, TRUE)
29+
# terminates with result
30+
[Y_n, C_n] = kmeans(X=data, k=16, runs= 1, max_iter=as.integer($4), eps= 1e-17, seed= 13, is_verbose=TRUE)
31+
print(sum(Y_n))

0 commit comments

Comments
 (0)