Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/DaphneDSL/Builtins.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ The following built-in functions allow to find out meta data of matrices and fra

Returns the DAPHNE compiler's *estimate* of the argument's sparsity.
Note that this value may deviate from the *actual* sparsity of the data at run-time.

- **`isSymmetric`**`(arg:matrix)`

Returns `true` if and only if the given *(n x n)* matrix is symmetric, i.e., for any *i in {0, 1, ..., n-1}* and *j in {0, 1, ..., n-1}*, the value at position *(i, j)* is the same as the value at position *(j, i)*.
The given matrix must be a square matrix.

## Elementwise unary

Expand Down
9 changes: 9 additions & 0 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ def SparsityOp : Daphne_Op<"sparsity", [DataTypeSca]> {
let hasCanonicalizeMethod = 1;
}

def Daphne_IsSymmetricOp : Daphne_Op<"isSymmetric"> {
let summary = "Checks if a matrix is symmetric";
let description = [{
This operation checks if the input matrix is symmetric.
}];
let arguments = (ins MatrixOrU:$arg);
let results = (outs BoolScalar:$res);
}

// ****************************************************************************
// Matrix multiplication
// ****************************************************************************
Expand Down
8 changes: 8 additions & 0 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,14 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
mlir::Value vec = args[1];
return utils.retValWithInferedType(builder.create<GemvOp>(loc, utils.unknownType, mat, vec));
}
if (func == "isSymmetric") {
// Check the function receives exactly one argument
checkNumArgsExact(loc, func, numArgs, 1);

// Create the IsSymmetricOp in the IR
mlir::Value arg = args[0];
return static_cast<mlir::Value>(builder.create<mlir::daphne::IsSymmetricOp>(loc, builder.getI1Type(), arg));
}

// ********************************************************************
// Extended relational algebra
Expand Down
25 changes: 25 additions & 0 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -6956,6 +6956,31 @@
]
}
]
},
{
"kernelTemplate": {
"header": "IsSymmetric.h",
"opName": "isSymmetric",
"returnType": "bool",
"templateParams": [
{
"name": "DTArg",
"isDataType": true
}
],
"runtimeParams": [
{
"type": "const DTArg *",
"name": "arg"
}
]
},
"instantiations": [
[["DenseMatrix", "double"]],
[["DenseMatrix", "int64_t"]],
[["CSRMatrix", "double"]],
[["CSRMatrix", "int64_t"]]
]
}

]
1 change: 1 addition & 0 deletions test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ MAKE_TEST_CASE("idxMax", 1)
MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("innerJoin", 1)
MAKE_TEST_CASE("isNan", 1)
MAKE_TEST_CASE("isSymmetric", 1)
MAKE_TEST_CASE("lower", 1)
MAKE_TEST_CASE("mean", 1)
MAKE_TEST_CASE("oneHot", 1)
Expand Down
5 changes: 5 additions & 0 deletions test/api/cli/operations/isSymmetric_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
X = [1, 2, 2, 1](2, 2);
Y = [1, 2, 3, 4](2, 2);

print(isSymmetric(X));
print(isSymmetric(Y));
2 changes: 2 additions & 0 deletions test/api/cli/operations/isSymmetric_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
0