Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
public class TemplateCell extends TemplateBase
{
private static final AggOp[] SUPPORTED_AGG =
new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX};
new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD};

public TemplateCell() {
super(TemplateType.CELL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

public class TemplateRow extends TemplateBase
{
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD};
private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
Expand Down
219 changes: 218 additions & 1 deletion src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public enum AggOp {
SUM_SQ,
MIN,
MAX,
PROD
}

protected final CellType _type;
Expand Down Expand Up @@ -332,12 +333,16 @@ private long executeDense(DenseBlock a, SideInput[] b, double[] scalars,
else if( _type == CellType.ROW_AGG ) {
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
return executeDenseRowAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
else if(_aggOp == AggOp.PROD)
return executeDenseRowProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
else
return executeDenseRowAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
}
else if( _type == CellType.COL_AGG ) {
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
return executeDenseColAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
else if(_aggOp == AggOp.PROD)
return executeDenseColProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
else
return executeDenseColAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
}
Expand Down Expand Up @@ -372,12 +377,16 @@ private long executeSparse(SparseBlock sblock, SideInput[] b, double[] scalars,
else if( _type == CellType.ROW_AGG ) {
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
return executeSparseRowAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
else if( _aggOp == AggOp.PROD)
return executeSparseRowProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
else
return executeSparseRowAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
}
else if( _type == CellType.COL_AGG ) {
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
return executeSparseColAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
else if( _aggOp == AggOp.PROD)
return executeSparseColProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
else
return executeSparseColAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
}
Expand Down Expand Up @@ -930,7 +939,215 @@ private double executeSparseAggMxx(SparseBlock sblock, SideInput[] b, double[] s
}
return ret;
}


private long executeDenseRowProd(DenseBlock a, SideInput[] b, double[] scalars,
DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
{
// single block output
double[] lc = c.valuesAt(0);
long lnnz = 0;
if(a == null && !sparseSafe) {
for(int i = rl; i < ru; i++) {
for(int j = 0; j < n; j++) {
if(j == 0) {
lc[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(lc[i] != 0) {
lc[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
break;
}
}
lnnz += (lc[i]!=0) ? 1 : 0;
}
}
else if( a != null ) {
for(int i = rl; i < ru; i++) {
double[] avals = a.values(i);
int aix = a.pos(i);
for(int j = 0; j < n; j++) {
double aval = avals[aix + j];
if(aval != 0 || !sparseSafe) {
if(j == 0) {
lc[i] = genexec(aval, b, scalars, m, n, rix+i, i, j);
} else if(lc[i] != 0) {
lc[i] *= genexec(aval, b, scalars, m, n, rix+i, i, j);
} else {
break;
}
} else {
break;
}
}
lnnz += (lc[i] != 0) ? 1 : 0;
}
}
return lnnz;
}

private long executeDenseColProd(DenseBlock a, SideInput[] b, double[] scalars,
DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
{
double[] lc = c.valuesAt(0);
//track the cols that have a zero
boolean[] zeroFlag = new boolean[n];
if(a == null && !sparseSafe) {
for(int i = rl; i < ru; i++) {
for(int j = 0; j < n; j++) {
if(!zeroFlag[j]) {
if(i == 0) {
lc[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(lc[j] != 0) {
lc[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
zeroFlag[j] = true;
}
}
}
}
}
else if(a != null) {
for(int i = rl; i < ru; i++) {
double[] avals = a.values(i);
int aix = a.pos(i);
for(int j = 0; j < n; j++) {
if(!zeroFlag[j]) {
double aval = avals[aix + j];
if(aval != 0 || !sparseSafe) {
if(i == 0) {
lc[j] = genexec(aval, b, scalars, m, n, rix + i, i, j);
} else if(lc[j] != 0) {
lc[j] *= genexec(aval, b, scalars, m, n, rix + i, i, j);
} else {
zeroFlag[j] = true;
}
}
} else {
zeroFlag[j] = true;
}
}
}
}
return -1;
}

private long executeSparseRowProd(SparseBlock sblock, SideInput[] b, double[] scalars,
MatrixBlock out, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
{
double[] c = out.getDenseBlockValues();
long lnnz = 0;
for(int i = rl; i < ru; i++) {
int lastj = -1;
if(sblock != null && !sblock.isEmpty(i)) {
int apos = sblock.pos(i);
int alen = sblock.size(i);
int[] aix = sblock.indexes(i);
double[] avals = sblock.values(i);
for(int k = apos; k < apos+alen; k++) {
if(!sparseSafe) {
for(int j=lastj+1; j<aix[k]; j++) {
if(j == 0) {
c[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(c[i] != 0){
c[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
break;
}
}
}
if(aix.length == n || !sparseSafe) {
if(aix[k] == 0) {
lastj = aix[k];
c[i] = genexec(avals[k], b, scalars, m, n, rix+i, i, k);
} else if(c[i] != 0){
lastj = aix[k];
c[i] *= genexec(avals[k], b, scalars, m, n, rix+i, i, k);
} else {
break;
}
} else {
break;
}
}
}
//process remaining zeros
if(!sparseSafe)
for(int j=lastj+1; j<n; j++) {
if(j == 0) {
c[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(c[i] != 0){
c[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
break;
}
}
lnnz += (c[i] != 0) ? 1 : 0;
}
return lnnz;
}

private long executeSparseColProd(SparseBlock sblock, SideInput[] b, double[] scalars,
MatrixBlock out, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
{
double[] c = out.getDenseBlockValues();
boolean[] zeroFlag = new boolean[n];

for(int i=rl; i<ru; i++) {
int lastj = -1;
//handle non-empty rows
if(sblock != null && !sblock.isEmpty(i)) {
int apos = sblock.pos(i);
int alen = sblock.size(i);
int[] aix = sblock.indexes(i);
double[] avals = sblock.values(i);
long nnzCount = sblock.size(rl, ru);
//process every column, to not miss any 0's
for(int k=apos; k<apos+alen; k++) {
//process zeros before current non-zero
if( !sparseSafe )
for(int j=lastj+1; j<aix[k]; j++) {
if(!zeroFlag[j]) {
if(i == 0) {
c[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(c[j] != 0){
c[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
zeroFlag[j] = true;
}
}
}
//process current non-zero
if((nnzCount == m || !sparseSafe) && !zeroFlag[aix[k]]) {
if(i == 0) {
lastj = aix[k];
c[aix[k]] = genexec(avals[k], b, scalars, m, n, rix+i, i, lastj);
} else if(c[aix[k]] != 0){
lastj = aix[k];
c[aix[k]] *= genexec(avals[k], b, scalars, m, n, rix+i, i, lastj);
} else {
zeroFlag[aix[k]] = true;
}
} else {
zeroFlag[aix[k]] = true;
}
}
}
//process empty rows or remaining zeros
if(!sparseSafe)
for(int j=lastj+1; j<n; j++) {
if(!zeroFlag[j]) {
if(i == 0) {
c[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
} else if(c[j] != 0){
c[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
} else {
zeroFlag[j] = true;
}
}
}
}
return -1;
}

//local execution where grix==rix
protected final double genexec( double a, SideInput[] b,
double[] scalars, int m, int n, int rix, int cix) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ public class CellwiseTmplTest extends AutomatedTestBase
private static final String TEST_NAME25 = TEST_NAME+25; //bias_add
private static final String TEST_NAME26 = TEST_NAME+26; //bias_mult
private static final String TEST_NAME27 = TEST_NAME+27; //outer < +7 negative
private static final String TEST_NAME28 = TEST_NAME+28; //colProds(X^2 + 1)
private static final String TEST_NAME29 = TEST_NAME+29; //colProds(2*log(X))
private static final String TEST_NAME30 = TEST_NAME+30; //rowProds(X^2 + 1)
private static final String TEST_NAME31 = TEST_NAME+31; //rowProds(2*log(X))

private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/";
Expand All @@ -79,7 +83,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
for( int i=1; i<=27; i++ ) {
for( int i=1; i<=31; i++ ) {
addTestConfiguration( TEST_NAME+i, new TestConfiguration(
TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
}
Expand Down Expand Up @@ -444,7 +448,7 @@ public void testCodegenCellwise26() {
public void testCodegenCellwiseRewrite26_sp() {
testCodegenIntegration( TEST_NAME26, true, ExecType.SPARK );
}

@Test
public void testCodegenCellwiseRewrite27() {
testCodegenIntegration( TEST_NAME27, true, ExecType.CP );
Expand All @@ -455,10 +459,71 @@ public void testCodegenCellwise27() {
testCodegenIntegration( TEST_NAME27, false, ExecType.CP );
}

@Test
public void testCodegenCellwiseRewrite27_sp() {
testCodegenIntegration( TEST_NAME27, true, ExecType.SPARK );
}

@Test
public void testCodegenCellwiseRewrite28() {
testCodegenIntegration( TEST_NAME28, true, ExecType.CP );
}

@Test
public void testCodegenCellwise28() {
testCodegenIntegration( TEST_NAME28, false, ExecType.CP );
}

@Test
public void testCodegenCellwiseRewrite28_sp() {
testCodegenIntegration( TEST_NAME28, true, ExecType.SPARK );
}

@Test
public void testCodegenCellwiseRewrite29() {
testCodegenIntegration( TEST_NAME29, true, ExecType.CP );
}

@Test
public void testCodegenCellwise29() {
testCodegenIntegration( TEST_NAME29, false, ExecType.CP );
}

@Test
public void testCodegenCellwiseRewrite29_sp() {
testCodegenIntegration( TEST_NAME29, true, ExecType.SPARK );
}

@Test
public void testCodegenCellwiseRewrite30() {
testCodegenIntegration( TEST_NAME30, true, ExecType.CP );
}

@Test
public void testCodegenCellwise30() {
testCodegenIntegration( TEST_NAME30, false, ExecType.CP );
}

@Test
public void testCodegenCellwiseRewrite30_sp() {
testCodegenIntegration( TEST_NAME30, true, ExecType.SPARK );
}

@Test
public void testCodegenCellwiseRewrite31() {
testCodegenIntegration( TEST_NAME31, true, ExecType.CP );
}

@Test
public void testCodegenCellwise31() {
testCodegenIntegration( TEST_NAME31, false, ExecType.CP );
}

@Test
public void testCodegenCellwiseRewrite31_sp() {
testCodegenIntegration( TEST_NAME31, true, ExecType.SPARK );
}

private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
boolean oldRewrites = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
Expand All @@ -467,15 +532,15 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType

if( testname.equals(TEST_NAME9) )
TEST_CONF = TEST_CONF6;

try
{
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-stats", "-args", output("S") };
programArgs = new String[]{"-explain", "codegen", "-stats", "-args", output("S") };

fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
Expand Down
Loading
Loading