Skip to content

Commit 6c4bffd

Browse files
committed
[SYSTEMDS-3819] New sliceLineExtract builtin function
This new sliceLineExtract builtin function takes the output of sliceLine and extracts the rows from X and e that belong to the top k2 <= k slices.
1 parent c5ab81c commit 6c4bffd

File tree

5 files changed

+72
-4
lines changed

5 files changed

+72
-4
lines changed

scripts/builtin/sliceLineDebug.dml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# INPUT:
2727
# ------------------------------------------------------------------------------
2828
# TK top-k slices (k x ncol(X) if successful)
29-
# TKC score, size, error of slices (k x 3)
29+
# TKC score, total/max error, size of slices (k x 4)
3030
# tfmeta transformencode meta data
3131
# tfspec transform specification
3232
# ------------------------------------------------------------------------------
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# This builtin function takes the outputs of sliceLine and extracts the rows of X and e that belong to the top-k2 slices.
23+
#
24+
#
25+
# INPUT:
26+
# ------------------------------------------------------------------------------
27+
# X Feature matrix in recoded/binned representation
28+
# e Error vector of trained model
29+
# TK top-k slices (k x ncol(X) if successful)
30+
# TKC score, total/max error, size of slices (k x 4)
31+
# k2 first k2 slices to extract with k2 <= k (k2 <= 0, the default, extracts all slices in TK)
32+
# ------------------------------------------------------------------------------
33+
#
34+
# OUTPUT:
35+
# ------------------------------------------------------------------------------
36+
# Xtk Selected rows from X which belong to k2 top slices
37+
# etk Selected rows from e which belong to k2 top slices
38+
# ------------------------------------------------------------------------------
39+
40+
# Extracts the rows of X and e that are covered by at least one of the first
# k2 slices in TK. TKC is accepted for interface symmetry with the sliceLine
# outputs but is not read by this function.
m_sliceLineExtract = function(Matrix[Double] X, Matrix[Double] e,
    Matrix[Double] TK, Matrix[Double] TKC, Integer k2 = -1)
  return(Matrix[Double] Xtk, Matrix[Double] etk)
{
  # check valid parameters: cannot extract more slices than TK provides
  if( k2 > nrow(TK) )
    stop("sliceLineExtract: invalid number of slices to extract: "+k2+" > "+nrow(TK));
  # non-positive k2 (including the default -1) means: use all slices in TK
  if( k2 <= 0 )
    k2 = nrow(TK);

  # extract first k2 slices from X and e
  # I[i,j] = 1 iff row j of X satisfies all predicates of slice i
  I = matrix(0, k2, nrow(X));
  parfor(i in 1:k2) {
    # assumes zero entries in TK[i,] mark unconstrained features -- TODO confirm
    # against sliceLine; only non-zero entries are treated as predicates
    P = (TK[i,] != 0);
    # a row belongs to the slice if it matches every predicate column; the
    # match count is compared against the number of predicates, not against
    # the sum of the encoded feature values
    I[i,] = t(rowSums((X == TK[i,]) * P) == sum(P))
  }
  I = t(colSums(I)); #union (non-zero entry = row is in at least one slice)

  # keep only the selected rows; the same selection vector is applied to X
  # and e so that Xtk and etk stay row-aligned
  Xtk = removeEmpty(target=X, margin="rows", select=I);
  etk = removeEmpty(target=e, margin="rows", select=I);
}
60+

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ public enum Builtins {
312312
SLICEFINDER("slicefinder", true), //TODO remove
313313
SLICELINE("sliceLine", true),
314314
SLICELINE_DEBUG("sliceLineDebug", true),
315+
SLICELINE_EXTRACT("sliceLineExtract", true),
315316
SKEWNESS("skewness", true),
316317
SMAPE("smape", true),
317318
SMOTE("smote", true),

src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSliceLineRealDataTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class BuiltinSliceLineRealDataTest extends AutomatedTestBase {
3939
@Override
4040
public void setUp() {
4141
for(int i=1; i<=1; i++)
42-
addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
42+
addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R","V"}));
4343
}
4444

4545
@Test
@@ -55,12 +55,14 @@ private void runSliceLine(int test, String data, String tfspec, double minAcc, E
5555
String HOME = SCRIPT_DIR + TEST_DIR;
5656
fullDMLScriptName = HOME + TEST_NAME + ".dml";
5757
programArgs = new String[] {"-stats",
58-
"-args", data, tfspec, output("R")};
58+
"-args", data, tfspec, output("R"), output("V")};
5959

6060
runTest(true, false, null, -1);
6161

6262
double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
63+
double val = readDMLMatrixFromOutputDir("V").get(new CellIndex(1,1));
6364
Assert.assertTrue(acc >= minAcc);
65+
Assert.assertTrue(val >= 0.99);
6466
Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
6567
}
6668
finally {

src/test/scripts/functions/builtin/sliceLineRealData.dml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,14 @@ acc = lmPredictStats(yhat, y, TRUE);
4545
e = (y-yhat)^2;
4646

4747
# model debugging via sliceline
48-
[TK,TKC,D] = slicefinder(X=X, e=e, k=4, alpha=0.95, minSup=32, tpBlksz=16, verbose=TRUE)
48+
[TK,TKC,D] = sliceLine(X=X, e=e, k=4, alpha=0.95, minSup=32, tpBlksz=16, verbose=TRUE)
4949
tfspec2 = "{ ids:true, recode:[1,2,5], bin:[{id:3, method:equi-width, numbins:10},{id:4, method:equi-width, numbins:10}]}"
5050
XYZ = sliceLineDebug(TK=TK, TKC=TKC, tfmeta=meta, tfspec=tfspec2)
51+
[Xtk,etk] = sliceLineExtract(X=X, e=e, TK=TK, TKC=TKC, k2=3);
5152

5253
acc = acc[3,1];
54+
val = as.matrix((sum(TKC[1,4]) >= nrow(Xtk)) & (nrow(Xtk) == nrow(etk)))
55+
5356
write(acc, $3);
57+
write(val, $4);
58+

0 commit comments

Comments
 (0)