diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 03dcb33f4..3a71e4115 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -19,26 +19,13 @@ #include #include -#include - -#include "mlir/Support/LogicalResult.h" -#define GET_OP_CLASSES -#include -#define GET_TYPEDEF_CLASSES -#include -#include -#include -#include - -#include -#include - #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Support/LogicalResult.h" // #include "mlir/IR/FunctionImplementation.h" // Removed in newer LLVM #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" @@ -54,6 +41,12 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include +#include +#include +#include +#include #include #include @@ -63,6 +56,12 @@ #include #include +#define GET_OP_CLASSES +#include +#define GET_TYPEDEF_CLASSES +#include +#include + struct DaphneInlinerInterface : public mlir::DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -115,200 +114,179 @@ mlir::Operation *mlir::daphne::DaphneDialect::materializeConstant(OpBuilder &bui return builder.create(loc, type, value); } -mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser) const { - llvm::StringRef keyword; - mlir::ParseResult pr = parser.parseKeyword(&keyword); - if (mlir::failed(pr)) - throw std::runtime_error("parsing a DaphneIR type failed"); - // `Matrix` `<` (`?` | \d+) `x` (`?` | \d+) `x` \type - // (`:` ( - // `sp` `[` \float `]` | - // `rep` `[` (`dense` | `sparse`) `]` - // ))* - if (keyword == "Matrix") { - ssize_t numRows = -1; - ssize_t numCols = -1; - double sparsity = -1.0; - MatrixRepresentation representation = MatrixRepresentation::Default; // default is dense - BoolOrUnknown symmetric = BoolOrUnknown::Unknown; - mlir::Type elementType; - if (parser.parseLess()) { - return nullptr; - } - if (parser.parseOptionalQuestion()) { - // Parse #rows if there was no '?'. - if (parser.parseInteger(numRows)) { - return nullptr; - } - } - if (parser.parseXInDimensionList()) { - return nullptr; - } - if (parser.parseOptionalQuestion()) { - // Parse #cols if there was no '?'. - if (parser.parseInteger(numCols)) { - return nullptr; - } - } - if (parser.parseXInDimensionList() || parser.parseType(elementType)) { - return nullptr; - } - // additional properties (only print/read them when present, as this - // will probably get more and more) - while (succeeded(parser.parseOptionalColon())) { - if (succeeded(parser.parseOptionalKeyword("sp"))) { - if (sparsity != -1.0) { - // read sparsity twice - return nullptr; - } - if (parser.parseLSquare() || parser.parseFloat(sparsity) || parser.parseRSquare()) { - return nullptr; - } - } else if (succeeded(parser.parseOptionalKeyword("rep"))) { - llvm::StringRef repName; - if (parser.parseLSquare() || parser.parseKeyword(&repName) || parser.parseRSquare()) { - return nullptr; - } - representation = stringToMatrixRepresentation(repName.str()); - } else if (succeeded(parser.parseOptionalKeyword("symmetric"))) { - llvm::StringRef symmetricStr; - if (parser.parseLSquare() || parser.parseKeyword(&symmetricStr) || parser.parseRSquare()) { - return nullptr; - } - symmetric = stringToBoolOrUnknown(symmetricStr.str()); - } else { - return nullptr; - } - } - if (parser.parseGreater()) { - return nullptr; - } +namespace { +mlir::LogicalResult parseOptionalDim(mlir::AsmParser &parser, ssize_t &value) { + if (succeeded(parser.parseOptionalQuestion())) + return mlir::success(); + return parser.parseInteger(value); +} +} // namespace - return MatrixType::get(parser.getBuilder().getContext(), elementType, numRows, numCols, sparsity, - representation, symmetric); - } else if (keyword == "Frame") { - ssize_t numRows = -1; - ssize_t numCols = -1; - if (parser.parseLess() || parser.parseOptionalQuestion() || - // TODO Parse #rows if there was no '?'. - // parser.parseInteger(numRows) || - parser.parseKeyword("x") || parser.parseLSquare() || parser.parseOptionalQuestion() || - // TODO Parse #cols if there was no '?'. - // parser.parseInteger(numCols) || - // TODO Parse sparsity - parser.parseColon()) { - return nullptr; +std::string unknownStrIf(ssize_t val) { return (val == -1) ? "?" : std::to_string(val); } + +std::string unknownStrIf(double val) { return (val == -1.0) ? "?" : std::to_string(val); } + +mlir::Type mlir::daphne::MatrixType::parse(mlir::AsmParser &parser) { + ssize_t numRows = -1; + ssize_t numCols = -1; + double sparsity = -1.0; + MatrixRepresentation representation = MatrixRepresentation::Default; + BoolOrUnknown symmetric = BoolOrUnknown::Unknown; + mlir::Type elementType; + + if (parser.parseLess() || failed(parseOptionalDim(parser, numRows)) || parser.parseXInDimensionList() || + failed(parseOptionalDim(parser, numCols)) || parser.parseXInDimensionList() || parser.parseType(elementType)) + return mlir::Type(); + + while (succeeded(parser.parseOptionalColon())) { + if (succeeded(parser.parseOptionalKeyword("sp"))) { + if (parser.parseLSquare() || parser.parseFloat(sparsity) || parser.parseRSquare()) + return mlir::Type(); + } else if (succeeded(parser.parseOptionalKeyword("rep"))) { + llvm::StringRef repName; + if (parser.parseLSquare() || parser.parseKeyword(&repName) || parser.parseRSquare()) + return mlir::Type(); + representation = stringToMatrixRepresentation(repName.str()); + } else if (succeeded(parser.parseOptionalKeyword("symmetric"))) { + llvm::StringRef symmetricStr; + if (parser.parseLSquare() || parser.parseKeyword(&symmetricStr) || parser.parseRSquare()) + return mlir::Type(); + symmetric = stringToBoolOrUnknown(symmetricStr.str()); + } else { + return mlir::Type(); } - std::vector cts; + } + + if (parser.parseGreater()) + return mlir::Type(); + + return MatrixType::get(parser.getBuilder().getContext(), elementType, numRows, numCols, sparsity, representation, + symmetric); +} + +void mlir::daphne::MatrixType::print(mlir::AsmPrinter &os) const { + os << "Matrix<" << unknownStrIf(getNumRows()) << 'x' << unknownStrIf(getNumCols()) << 'x' << getElementType(); + if (auto sparsity = getSparsity(); sparsity != -1.0) + os << ":sp[" << sparsity << ']'; + if (auto representation = getRepresentation(); representation != MatrixRepresentation::Default) + os << ":rep[" << matrixRepresentationToString(representation) << ']'; + if (auto symmetric = getSymmetric(); symmetric != BoolOrUnknown::Unknown) + os << ":symmetric[" << boolOrUnknownToString(symmetric) << ']'; + os << '>'; +} + +mlir::Type mlir::daphne::FrameType::parse(mlir::AsmParser &parser) { + ssize_t numRows = -1; + ssize_t numCols = -1; + llvm::SmallVector columnTypes; + + if (parser.parseLess() || failed(parseOptionalDim(parser, numRows)) || parser.parseKeyword("x") || + parser.parseLSquare() || failed(parseOptionalDim(parser, numCols)) || parser.parseColon()) + return mlir::Type(); + + if (failed(parser.parseOptionalRSquare())) { mlir::Type type; do { if (parser.parseType(type)) - return nullptr; - cts.push_back(type); + return mlir::Type(); + columnTypes.push_back(type); } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare() || parser.parseGreater()) { - return nullptr; - } - return FrameType::get(parser.getBuilder().getContext(), cts, numRows, numCols, nullptr); - } else if (keyword == "Handle") { - mlir::Type dataType; - if (parser.parseLess() || parser.parseType(dataType) || parser.parseGreater()) { - return nullptr; - } - return mlir::daphne::HandleType::get(parser.getBuilder().getContext(), dataType); - } else if (keyword == "String") { - return StringType::get(parser.getBuilder().getContext()); - } else if (keyword == "Column") { - if (parser.parseLess()) - return nullptr; - ssize_t numRows = -1; - if (parser.parseOptionalQuestion()) - // Parse #rows if there was no '?'. - if (parser.parseInteger(numRows)) - return nullptr; - if (parser.parseXInDimensionList()) - return nullptr; - mlir::Type vt; - if (parser.parseType(vt)) - return nullptr; - if (parser.parseGreater()) - return nullptr; - return ColumnType::get(parser.getBuilder().getContext(), vt, numRows); - } else if (keyword == "DaphneContext") { - return mlir::daphne::DaphneContextType::get(parser.getBuilder().getContext()); - } else { - parser.emitError(parser.getCurrentLocation()) << "Parsing failed, keyword `" << keyword << "` not recognized!"; - return nullptr; + if (parser.parseRSquare()) + return mlir::Type(); } -} -std::string unknownStrIf(ssize_t val) { return (val == -1) ? "?" : std::to_string(val); } + if (parser.parseComma()) + return mlir::Type(); -std::string unknownStrIf(double val) { return (val == -1.0) ? "?" : std::to_string(val); } + std::vector *labels = nullptr; + if (failed(parser.parseOptionalQuestion())) { + std::string label; + std::vector parsedLabels; + if (parser.parseLSquare()) + return mlir::Type(); + do { + if (parser.parseString(&label)) + return mlir::Type(); + parsedLabels.push_back(label); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) + return mlir::Type(); + // Allocate labels to keep them alive for the lifetime of the context. + labels = new std::vector(parsedLabels); + } -void mlir::daphne::DaphneDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (llvm::isa(type)) - os << "Structure"; - else if (auto t = llvm::dyn_cast(type)) { - os << "Matrix<" << unknownStrIf(t.getNumRows()) << 'x' << unknownStrIf(t.getNumCols()) << 'x' - << t.getElementType(); - auto sparsity = t.getSparsity(); - auto representation = t.getRepresentation(); - auto symmetric = t.getSymmetric(); - - if (sparsity != -1.0) { - os << ":sp[" << sparsity << ']'; - } - if (representation != MatrixRepresentation::Default) { - os << ":rep[" << matrixRepresentationToString(representation) << ']'; - } - if (symmetric != BoolOrUnknown::Unknown) { - os << ":symmetric[" << boolOrUnknownToString(symmetric) << ']'; - } - os << '>'; - } else if (auto t = llvm::dyn_cast(type)) { - os << "Frame<" << unknownStrIf(t.getNumRows()) << "x[" << unknownStrIf(t.getNumCols()) << ": "; - // Column types. - std::vector cts = t.getColumnTypes(); - for (size_t i = 0; i < cts.size(); i++) { - os << cts[i]; - if (i < cts.size() - 1) + if (parser.parseGreater()) + return mlir::Type(); + + return FrameType::get(parser.getBuilder().getContext(), + std::vector(columnTypes.begin(), columnTypes.end()), numRows, numCols, labels); +} + +void mlir::daphne::FrameType::print(mlir::AsmPrinter &os) const { + os << "Frame<" << unknownStrIf(getNumRows()) << "x[" << unknownStrIf(getNumCols()) << ": "; + std::vector cts = getColumnTypes(); + for (size_t i = 0; i < cts.size(); i++) { + os << cts[i]; + if (i < cts.size() - 1) + os << ", "; + } + os << "], "; + std::vector *labels = getLabels(); + if (labels) { + os << '['; + for (size_t i = 0; i < labels->size(); i++) { + os << '"' << (*labels)[i] << '"'; + if (i < labels->size() - 1) os << ", "; } - os << "], "; - // Column labels. - std::vector *labels = t.getLabels(); - if (labels) { - os << '['; - for (size_t i = 0; i < labels->size(); i++) { - os << '"' << (*labels)[i] << '"'; - if (i < labels->size() - 1) - os << ", "; - } - os << ']'; - } else - os << '?'; - os << '>'; - } else if (auto t = llvm::dyn_cast(type)) { - os << "Column<" << unknownStrIf(t.getNumRows()) << "x" << t.getValueType() << '>'; - } else if (auto t = llvm::dyn_cast(type)) { - os << "List<" << t.getElementType() << '>'; - } else if (auto handle = llvm::dyn_cast(type)) { - os << "Handle<" << handle.getDataType() << ">"; - } else if (llvm::isa(type)) - os << "String"; - else if (auto t = llvm::dyn_cast(type)) - os << "VariadicPack<" << t.getContainedType() << '>'; - else if (llvm::isa(type)) - os << "DaphneContext"; - else if (llvm::isa(type)) - os << "File"; - else if (llvm::isa(type)) - os << "Descriptor"; - else if (llvm::isa(type)) - os << "Target"; - else if (llvm::isa(type)) - os << "Unknown"; + os << ']'; + } else + os << '?'; + os << '>'; +} + +mlir::Type mlir::daphne::ColumnType::parse(mlir::AsmParser &parser) { + ssize_t numRows = -1; + mlir::Type valueType; + + if (parser.parseLess() || failed(parseOptionalDim(parser, numRows)) || parser.parseXInDimensionList() || + parser.parseType(valueType) || parser.parseGreater()) + return mlir::Type(); + + return ColumnType::get(parser.getBuilder().getContext(), valueType, numRows); +} + +void mlir::daphne::ColumnType::print(mlir::AsmPrinter &os) const { + os << "Column<" << unknownStrIf(getNumRows()) << "x" << getValueType() << '>'; +} + +mlir::Type mlir::daphne::ListType::parse(mlir::AsmParser &parser) { + mlir::Type elementType; + if (parser.parseLess() || parser.parseType(elementType) || parser.parseGreater()) + return mlir::Type(); + return ListType::get(parser.getBuilder().getContext(), elementType); +} + +void mlir::daphne::ListType::print(mlir::AsmPrinter &os) const { os << "List<" << getElementType() << '>'; } + +mlir::Type mlir::daphne::HandleType::parse(mlir::AsmParser &parser) { + mlir::Type dataType; + if (parser.parseLess() || parser.parseType(dataType) || parser.parseGreater()) + return mlir::Type(); + return HandleType::get(parser.getBuilder().getContext(), dataType); +} + +void mlir::daphne::HandleType::print(mlir::AsmPrinter &os) const { os << "Handle<" << getDataType() << ">"; } + +mlir::Type mlir::daphne::VariadicPackType::parse(mlir::AsmParser &parser) { + mlir::Type containedType; + if (parser.parseLess() || parser.parseType(containedType) || parser.parseGreater()) + return mlir::Type(); + return VariadicPackType::get(parser.getBuilder().getContext(), containedType); +} + +void mlir::daphne::VariadicPackType::print(mlir::AsmPrinter &os) const { + os << "VariadicPack<" << getContainedType() << '>'; } std::string mlir::daphne::matrixRepresentationToString(MatrixRepresentation rep) { diff --git a/src/ir/daphneir/DaphneDialect.td b/src/ir/daphneir/DaphneDialect.td index 92c5be22f..bb959e093 100644 --- a/src/ir/daphneir/DaphneDialect.td +++ b/src/ir/daphneir/DaphneDialect.td @@ -32,19 +32,7 @@ def Daphne_Dialect : Dialect { let cppNamespace = "::mlir::daphne"; let hasConstantMaterializer = 1; - - // TODO It should not be necessary to add the declarations of parseType() - // and printType() manually. However, I cannot find a way to have only the - // declarations (but not the definitions) generated with this version of - // MLIR. Using useDefaultTypePrinterParser messes things up. - let extraClassDeclaration = [{ - /// Parse a type registered to this dialect. - ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; - - /// Print a type registered to this dialect. - void printType(::mlir::Type type, - ::mlir::DialectAsmPrinter &os) const override; - }]; + let useDefaultTypePrinterParser = 1; } -#endif //SRC_IR_DAPHNEIR_DAPHNEDIALECT_TD \ No newline at end of file +#endif //SRC_IR_DAPHNEIR_DAPHNEDIALECT_TD diff --git a/src/ir/daphneir/DaphneTypes.td b/src/ir/daphneir/DaphneTypes.td index 862d6047d..7c6c13ef0 100644 --- a/src/ir/daphneir/DaphneTypes.td +++ b/src/ir/daphneir/DaphneTypes.td @@ -27,6 +27,7 @@ include "mlir/IR/EnumAttr.td" // The base for all DaphneIR-specific types. class Daphne_Type : TypeDef { let typeName = name; + let mnemonic = name; } // **************************************************************************** @@ -55,6 +56,7 @@ def Structure : Daphne_Type<"Structure"> { // A matrix type. def Matrix : Daphne_Type<"Matrix"> { let summary = "matrix"; + let hasCustomAssemblyFormat = 1; // TODO Maybe we should rename "element type" to "value type" everywhere. let parameters = (ins @@ -156,6 +158,7 @@ class MatrixOf allowedTypes> : AnyTypeOf<[ def Frame : Daphne_Type<"Frame"> { let summary = "frame"; + let hasCustomAssemblyFormat = 1; let parameters = (ins "std::vector<::mlir::Type>":$columnTypes, @@ -238,6 +241,7 @@ def FrameOrU : AnyTypeOf<[Frame, Unknown]>; def Column : Daphne_Type<"Column"> { let summary = "column"; + let hasCustomAssemblyFormat = 1; let parameters = (ins "::mlir::Type":$valueType, "ssize_t":$numRows); @@ -306,6 +310,7 @@ class ColumnOf allowedTypes> : AnyTypeOf<[ def List : Daphne_Type<"List"> { let summary = "list"; + let hasCustomAssemblyFormat = 1; let parameters = (ins "::mlir::Type":$elementType @@ -320,6 +325,7 @@ def ListOrU : AnyTypeOf<[List, Unknown]>; def Handle : Daphne_Type<"Handle"> { let summary = "handle to distributed data"; + let hasCustomAssemblyFormat = 1; let parameters = (ins "::mlir::Type":$dataType); let genVerifyDecl = 1; @@ -384,6 +390,7 @@ def VariadicPack : Daphne_Type<"VariadicPack"> { stack. It is introduced during the lowering to kernel calls and converted when lowering to the LLVM dialect. }]; + let hasCustomAssemblyFormat = 1; let parameters = (ins "::mlir::Type":$containedType); }