Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion doc/DaphneDSL/Builtins.md
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,23 @@ The following built-in functions all follow the same scheme:
| `cumMin` | cumulative minimum |
| `cumMax` | cumulative maximum |

## Map

Standard element-wise mapping, as well as row- and column-wise mapping is supported.

- **`map`**`(arg:matrix, func:str)`

Element-wise mapping over a *(n x m)* matrix `arg` using a user-defined function `func` written in DaphneDSL. Applies the given UDF to each element of the given matrix.

- **`map`**`(arg:matrix, func:str, axis:si64[, udfReturnsScalar:bool])`

Row- or column-wise mapping over a *(n x m)* matrix `arg` using a user-defined function `func` written in DaphneDSL. Applies the given UDF to each row/column of the given matrix. If the input of the UDF is a row matrix, the output can be a row or a scalar; if the input is a column matrix, the output can be a column or a scalar.

- `axis` == 0: Map an entire row of the input matrix to an entire row of the output matrix; the result is a *(n x ?)* matrix
- `axis` == 1: Map an entire column of the input matrix to an entire column of the output matrix; the result is a *(? x m)* matrix
- `udfReturnsScalar` == false (optional): The given UDF `func` returns a matrix (default for row-/column-wise map), must match the UDFs output type; the result is as previously described
- `udfReturnsScalar` == true: The given UDF `func` returns a scalar; the result is a *(n x 1)* (column) or matrix *(1 x m)* (row) matrix (depending on `axis`)

## Reorganization

- **`reshape`**`(arg:matrix, numRows:size, numCols:size)`
Expand Down Expand Up @@ -701,4 +718,4 @@ These must be provided in a separate [`.meta`-file](/doc/FileMetaDataFormat.md).
- **`remove`**`(lst:list, idx:size)`

Removes the element at position `idx` (counting starts at zero) from the given list `lst`.
Returns (1) the result as a new list (the argument list stays unchanged), and (2) the removed element.
Returns (1) the result as a new list (the argument list stays unchanged), and (2) the removed element.
8 changes: 7 additions & 1 deletion src/compiler/lowering/LowerToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,17 @@ class MapOpLowering : public OpConversionPattern<daphne::MapOp> {
// Pointer to UDF
callee << "__void";

// Axis
callee << "__int64_t";

// udfReturnsScalar
callee << "__bool";

// get pointer to UDF
LLVM::LLVMFuncOp udfFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(op.getFunc());
auto udfFnPtr = rewriter.create<LLVM::AddressOfOp>(loc, udfFuncOp);

std::vector<Value> kernelOperands{op.getArg(), udfFnPtr};
std::vector<Value> kernelOperands{op.getArg(), udfFnPtr, op.getAxis(), op.getUdfReturnsScalar()};

auto kernel = rewriter.create<daphne::CallKernelOp>(loc, callee.str(), kernelOperands, op->getResultTypes());
rewriter.replaceOp(op, kernel.getResults());
Expand Down
58 changes: 49 additions & 9 deletions src/compiler/lowering/SpecializeGenericFunctionsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,19 @@ class SpecializeGenericFunctionsPass : public PassWrapper<SpecializeGenericFunct
// mapped on
mlir::Type opTy = mapOp.getArg().getType();
auto inpMatrixTy = opTy.dyn_cast<daphne::MatrixType>();
func::FuncOp specializedFunc =
createOrReuseSpecialization(inpMatrixTy.getElementType(), {}, calledFunction, mapOp.getLoc());
int64_t axis = CompilerUtils::constantOrThrow<int64_t>(mapOp.getAxis(), "map axis must be a constant.");
func::FuncOp specializedFunc;
// Set function input type based on given axis
if (axis == 0) { // row-wise map
specializedFunc =
createOrReuseSpecialization(inpMatrixTy.withShape(1, -1), {}, calledFunction, mapOp.getLoc());
} else if (axis == 1) { // column-wise map
specializedFunc =
createOrReuseSpecialization(inpMatrixTy.withShape(-1, 1), {}, calledFunction, mapOp.getLoc());
} else { // element-wise
specializedFunc =
createOrReuseSpecialization(inpMatrixTy.getElementType(), {}, calledFunction, mapOp.getLoc());
}
mapOp.setFuncAttr(specializedFunc.getSymNameAttr());

// We only allow functions that return exactly one result for
Expand All @@ -385,16 +396,45 @@ class SpecializeGenericFunctionsPass : public PassWrapper<SpecializeGenericFunct

// Get current mapOp result matrix type and fix it if needed.
// If we fixed something we rerun inference of the whole
// function
// function.
daphne::MatrixType resMatrixTy = mapOp.getType().dyn_cast<daphne::MatrixType>();
mlir::Type funcResTy = specializedFunc.getFunctionType().getResult(0);
bool madeChanges = false;

auto udfReturnsScalar = CompilerUtils::constantOrThrow<bool>(
mapOp.getUdfReturnsScalar(), "map parameter udfReturnsScalar must be a bool.");

// If the specialized function returns a scalar, the previously
// unknown dimension is set to one, and if the specialized function
// returns a matrix, this dimension is still unknown.
if (axis == 0 || axis == 1) {
if (dyn_cast_or_null<daphne::MatrixType>(funcResTy) && !udfReturnsScalar) { // Matrix -> Matrix
// Set function result type to the matrix's element
// type for further processing
funcResTy = dyn_cast<daphne::MatrixType>(funcResTy).getElementType();
} else if (!dyn_cast_or_null<daphne::MatrixType>(funcResTy) &&
udfReturnsScalar) { // Matrix -> Scalar
if (axis == 0)
resMatrixTy = resMatrixTy.withShape(inpMatrixTy.getNumRows(), 1);
else if (axis == 1)
resMatrixTy = resMatrixTy.withShape(1, inpMatrixTy.getNumCols());
madeChanges = true;
} else { // udfReturnsScalar does not match funcResTy
throw ErrorHandler::compilerError(
mapOp.getOperation(), "SpecializeGenericFunctionsPass",
"map parameter udfReturnsScalar does not match the output type of the provided function.");
}
}

// The matrix that results from the mapOp has the same
// element-type returned by the specialized function
if (resMatrixTy.getElementType() != funcResTy) {
resMatrixTy = resMatrixTy.withElementType(funcResTy);
madeChanges = true;
}

// The matrix that results from the mapOp has the same dimension
// as the input matrix and the element-type returned by the
// specialized function
if (resMatrixTy.getNumCols() != inpMatrixTy.getNumCols() ||
resMatrixTy.getNumRows() != inpMatrixTy.getNumRows() || resMatrixTy.getElementType() != funcResTy) {
mapOp.getResult().setType(inpMatrixTy.withElementType(funcResTy));
if (madeChanges) {
mapOp.getResult().setType(resMatrixTy);
inferTypesInFunction(function);
}

Expand Down
24 changes: 24 additions & 0 deletions src/ir/daphneir/DaphneInferShapeOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,30 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::RecodeOp::inferShape() {
return {{resNumRows, resNumCols}, {dictNumRows, dictNumCols}};
}

std::vector<std::pair<ssize_t, ssize_t>> daphne::MapOp::inferShape() {
mlir::Type opTy = getArg().getType();
auto inpMatrixTy = opTy.dyn_cast<daphne::MatrixType>();

// For element-wise mapOp, the result matrix has the same
// dimension as the input matrix
ssize_t resNumRows = inpMatrixTy.getNumRows();
ssize_t resNumCols = inpMatrixTy.getNumCols();

int64_t axis = CompilerUtils::constantOrThrow<int64_t>(getAxis(), "map axis must be a constant.");

// For row- and column-wise mapOp, the result matrix does not
// have the same dimensions as the input matrix.
// During a row-wise mapOp the number of rows stays the same
// as the input matrix, and during a column-wise mapOp the
// number of columns stays the same.
if (axis == 0)
resNumCols = -1;
else if (axis == 1)
resNumRows = -1;

return {{resNumRows, resNumCols}};
}

// ****************************************************************************
// Shape inference trait implementations
// ****************************************************************************
Expand Down
4 changes: 2 additions & 2 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1918,9 +1918,9 @@ def Daphne_DistributedPipelineOp : Daphne_Op<"distributedPipeline", [AttrSizedOp
// ****************************************************************************
// Higher-order operations
// ****************************************************************************
def Daphne_MapOp : Daphne_Op<"map", [ShapeFromArg]> {
def Daphne_MapOp : Daphne_Op<"map", [DeclareOpInterfaceMethods<InferShapeOpInterface>]> {
let summary = "Applies a user defined function to elements of a matrix.";
let arguments = (ins MatrixOrU:$arg, SymbolNameAttr:$func);
let arguments = (ins MatrixOrU:$arg, SymbolNameAttr:$func, SI64:$axis, BoolScalar:$udfReturnsScalar);
let results = (outs MatrixOrU:$res);
}

Expand Down
24 changes: 21 additions & 3 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,14 +1299,32 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
// ****************************************************************************

if (func == "map") {
checkNumArgsExact(loc, func, numArgs, 2);
checkNumArgsBetween(loc, func, numArgs, 2, 4);

mlir::Value source = args[0];

auto co = args[1].getDefiningOp<mlir::daphne::ConstantOp>();
mlir::Attribute attr = co.getValue();

return static_cast<mlir::Value>(
builder.create<MapOp>(loc, source.getType(), source, attr.dyn_cast<mlir::StringAttr>()));
// Default values, if not given
mlir::Value axis = builder.create<ConstantOp>(loc, int64_t(-1));
mlir::Value udfReturnsScalar = builder.create<ConstantOp>(loc, false);

if (numArgs >= 3) { // axis is given
int64_t axisInt =
CompilerUtils::constantOrThrow<int64_t>(args[2], "third argument of map must be a constant");
if (axisInt == 0 || axisInt == 1)
axis = args[2];
else
throw ErrorHandler::compilerError(loc, "DSLBuiltins", "invalid axis for aggregation.");
}

if (numArgs == 4) { // udfReturnsScalar is given
udfReturnsScalar = args[3];
}

return static_cast<mlir::Value>(builder.create<MapOp>(
loc, source.getType(), source, attr.dyn_cast<mlir::StringAttr>(), axis, udfReturnsScalar));
}

// ****************************************************************************
Expand Down
14 changes: 12 additions & 2 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,9 +892,9 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr
throw ErrorHandler::compilerError(loc, "DSLVisitor",
"called 'handleMapOpCall' for function " + func + " instead of 'map'");

if (ctx->expr().size() != 2) {
if (ctx->expr().size() < 2 || ctx->expr().size() > 4) {
throw ErrorHandler::compilerError(loc, "DSLVisitor",
"built-in function 'map' expects exactly 2 argument(s), but got " +
"built-in function 'map' expects 2-4 argument(s), but got " +
std::to_string(ctx->expr().size()));
}

Expand All @@ -918,6 +918,16 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr
args.push_back(
static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(loc, maybeUDF->getSymName().str())));

if (ctx->expr().size() >= 3) {
auto axis = valueOrErrorOnVisit(ctx->expr(2));
args.push_back(axis);

if (ctx->expr().size() == 4) {
auto udfReturnsScalar = utils.castBoolIf(valueOrErrorOnVisit(ctx->expr(3)));
args.push_back(udfReturnsScalar);
}
}

// Create DaphneIR operation for the built-in function.
return builtins.build(loc, func, args);
}
Expand Down
Loading