Skip to content

Commit 8b4f3ce

Browse files
committed
column masks
1 parent f4aa252 commit 8b4f3ce

File tree

12 files changed

+347
-1
lines changed

12 files changed

+347
-1
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ public enum Builtins {
153153
GARCH("garch", true),
154154
GAUSSIAN_CLASSIFIER("gaussianClassifier", true),
155155
GET_ACCURACY("getAccuracy", true),
156+
GET_CATEGORICAL_MASK("getCategoricalMask", false),
156157
GLM("glm", true),
157158
GLM_PREDICT("glmPredict", true),
158159
GLOVE("glove", true),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ public enum Opcodes {
194194
TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin),
195195
TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin),
196196

197+
GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary),
198+
197199
//Ternary instruction opcodes
198200
PM("+*", InstructionType.Ternary),
199201
MINUSMULT("-*", InstructionType.Ternary),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ public enum OpOp2 {
635635
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
636636
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
637637
MINUS1_MULT(false), //1-X*Y
638+
GET_CATEGORICAL_MASK(false), // get transformation mask
638639
QUANTIZE_COMPRESS(false); //quantization-fused compression
639640

640641
private final boolean _validOuter;

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,9 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
864864
_etype = ExecType.CP;
865865
}
866866

867+
if( op == OpOp2.GET_CATEGORICAL_MASK)
868+
_etype = ExecType.CP;
869+
867870
//mark for recompile (forever)
868871
setRequiresRecompileIfNecessary();
869872

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,6 +2013,11 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
20132013
else
20142014
raiseValidateError("The compress or decompress instruction is not allowed in dml scripts");
20152015
break;
2016+
case GET_CATEGORICAL_MASK:
2017+
checkNumParameters(2);
2018+
checkFrameParam(getFirstExpr());
2019+
checkScalarParam(getSecondExpr());
2020+
break;
20162021
case QUANTIZE_COMPRESS:
20172022
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
20182023
checkNumParameters(2);
@@ -2333,6 +2338,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional
23332338
raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
23342339
}
23352340
}
2341+
2342+
protected void checkFrameParam(Expression e) {
2343+
if(e.getOutput().getDataType() != DataType.FRAME) {
2344+
raiseValidateError("Expecting frame parameter for function " + getOpCode(), false,
2345+
LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
2346+
}
2347+
}
23362348

23372349
protected void checkMatrixScalarParam(Expression e) { //always unconditional
23382350
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) {

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2817,6 +2817,9 @@ else if ( in.length == 2 )
28172817
DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr);
28182818
break;
28192819

2820+
case GET_CATEGORICAL_MASK:
2821+
currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp2.GET_CATEGORICAL_MASK, expr, expr2);
2822+
break;
28202823
default:
28212824
throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
28222825
}

src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
5151
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
5252
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
5353
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE,
54-
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
54+
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK,
5555
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
5656

5757

@@ -113,6 +113,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
113113
String2BuiltinCode.put( "_map", BuiltinCode.MAP);
114114
String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP);
115115
String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA);
116+
String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK);
116117
}
117118

118119
protected Builtin(BuiltinCode bf) {

src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T
5959
return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str);
6060
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME)
6161
return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
62+
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR)
63+
return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str);
6264
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX)
6365
return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
6466
else
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.cp;
21+
22+
import java.util.Arrays;
23+
24+
import org.apache.sysds.common.Builtins;
25+
import org.apache.sysds.common.Types.ValueType;
26+
import org.apache.sysds.runtime.DMLRuntimeException;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.frame.data.FrameBlock;
29+
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
30+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
31+
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
32+
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
33+
import org.apache.sysds.runtime.util.UtilFunctions;
34+
import org.apache.wink.json4j.JSONArray;
35+
import org.apache.wink.json4j.JSONObject;
36+
37+
public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
38+
// private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
39+
40+
protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out,
41+
String opcode, String istr) {
42+
super(CPType.Binary, op, in1, in2, out, opcode, istr);
43+
}
44+
45+
@Override
46+
public void processInstruction(ExecutionContext ec) {
47+
// get input frames
48+
FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
49+
ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true);
50+
if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) {
51+
processGetCategorical(ec, inBlock1, spec);
52+
}
53+
else {
54+
throw new DMLRuntimeException("Unsupported operation");
55+
}
56+
57+
// Release the memory occupied by input frames
58+
ec.releaseFrameInput(input1.getName());
59+
}
60+
61+
public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
62+
try {
63+
64+
// MatrixBlock ret = new MatrixBlock();
65+
int nCol = f.getNumColumns();
66+
67+
// System.out.println(spec);
68+
JSONObject jSpec = new JSONObject(spec.getStringValue());
69+
// System.out.println(jSpec);
70+
if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) {
71+
throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
72+
}
73+
74+
String recode = TfMethod.RECODE.toString();
75+
String dummycode = TfMethod.DUMMYCODE.toString();
76+
String hashCode = TfMethod.HASH.toString();
77+
78+
System.out.println(jSpec.keySet());
79+
80+
int[] lengths = new int[nCol];
81+
// assume all columns encode to at least one column.
82+
Arrays.fill(lengths, 1);
83+
boolean[] categorical = new boolean[nCol];
84+
85+
if(jSpec.containsKey(recode)) {
86+
JSONArray a = jSpec.getJSONArray(recode);
87+
for(Object aa : a) {
88+
int av = (Integer) aa - 1;
89+
categorical[av] = true;
90+
}
91+
}
92+
93+
if(jSpec.containsKey(dummycode)) {
94+
JSONArray a = jSpec.getJSONArray(dummycode);
95+
for(Object aa : a) {
96+
int av = (Integer) aa - 1;
97+
ColumnMetadata d = f.getColumnMetadata()[av];
98+
String v = f.getString(0, av);
99+
int ndist;
100+
if(v.length() > 1 && v.charAt(0) == '¿') {
101+
ndist = UtilFunctions.parseToInt(v.substring(1));
102+
}
103+
else {
104+
ndist = d.isDefault() ? 0 : (int) d.getNumDistinct();
105+
}
106+
lengths[av] = ndist;
107+
categorical[av] = true;
108+
}
109+
}
110+
111+
// get total size after mapping
112+
113+
int sumLengths = 0;
114+
for(int i : lengths) {
115+
sumLengths += i;
116+
}
117+
118+
MatrixBlock ret = new MatrixBlock(1, sumLengths, false);
119+
ret.allocateDenseBlock();
120+
int off = 0;
121+
for(int i = 0; i < lengths.length; i++) {
122+
for(int j = 0; j < lengths[i]; j++) {
123+
ret.set(0, off++, categorical[i] ? 1 : 0);
124+
}
125+
}
126+
127+
ec.setMatrixOutput(output.getName(), ret);
128+
129+
}
130+
catch(Exception e) {
131+
throw new DMLRuntimeException(e);
132+
}
133+
}
134+
}

src/test/java/org/apache/sysds/test/TestUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.io.FileInputStream;
3333
import java.io.FileOutputStream;
3434
import java.io.FileReader;
35+
import java.io.FileWriter;
3536
import java.io.IOException;
3637
import java.io.InputStreamReader;
3738
import java.io.OutputStreamWriter;
@@ -2927,6 +2928,25 @@ public static void writeTestScalar(String file, double value) {
29272928
}
29282929
}
29292930

2931+
2932+
/**
2933+
* Write scalar to file
2934+
*
2935+
* @param file File to write to
2936+
* @param value Value to write
2937+
*/
2938+
public static void writeTestScalar(String file, String value) {
2939+
try {
2940+
DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
2941+
try(PrintWriter pw = new PrintWriter(out)) {
2942+
pw.println(value);
2943+
}
2944+
}
2945+
catch(IOException e) {
2946+
fail("unable to write test scalar (" + file + "): " + e.getMessage());
2947+
}
2948+
}
2949+
29302950
/**
29312951
* Write scalar to file
29322952
*

0 commit comments

Comments
 (0)