diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 2da79011fa26a..4bd7f12fabf7b 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon) add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) +add_subdirectory(Polynomial) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt b/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt new file mode 100644 index 0000000000000..d8039deb5ee21 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect(Polynomial polynomial) +add_mlir_doc(PolynomialDialect PolynomialDialect Polynomial/ -gen-dialect-doc) +add_mlir_doc(PolynomialOps PolynomialOps Polynomial/ -gen-op-doc) +add_mlir_doc(PolynomialAttributes PolynomialAttributes Dialects/ -gen-attrdef-doc) +add_mlir_doc(PolynomialTypes PolynomialTypes Dialects/ -gen-typedef-doc) + +set(LLVM_TARGET_DEFINITIONS Polynomial.td) +mlir_tablegen(PolynomialAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=polynomial) +mlir_tablegen(PolynomialAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=polynomial) +add_public_tablegen_target(MLIRPolynomialAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h new file mode 100644 index 0000000000000..39b05b9d3ad14 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h @@ -0,0 +1,130 @@ +//===- Polynomial.h - A data class for polynomials --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_ +#define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +class MLIRContext; + +namespace polynomial { + +/// This restricts statically defined polynomials to have at most 64-bit +/// coefficients. This may be relaxed in the future, but it seems unlikely one +/// would want to specify 128-bit polynomials statically in the source code. +constexpr unsigned apintBitWidth = 64; + +/// A class representing a monomial of a single-variable polynomial with integer +/// coefficients. +class Monomial { +public: + Monomial(int64_t coeff, uint64_t expo) + : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {} + + Monomial(const APInt &coeff, const APInt &expo) + : coefficient(coeff), exponent(expo) {} + + Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {} + + bool operator==(const Monomial &other) const { + return other.coefficient == coefficient && other.exponent == exponent; + } + bool operator!=(const Monomial &other) const { + return other.coefficient != coefficient || other.exponent != exponent; + } + + /// Monomials are ordered by exponent. + bool operator<(const Monomial &other) const { + return (exponent.ult(other.exponent)); + } + + // Prints polynomial to 'os'. + void print(raw_ostream &os) const; + + friend ::llvm::hash_code hash_value(const Monomial &arg); + +public: + APInt coefficient; + + // Always unsigned + APInt exponent; +}; + +/// A single-variable polynomial with integer coefficients. +/// +/// Eg: x^1024 + x + 1 +/// +/// The symbols used as the polynomial's indeterminate don't matter, so long as +/// it is used consistently throughout the polynomial. +class Polynomial { +public: + Polynomial() = delete; + + explicit Polynomial(ArrayRef terms) : terms(terms){}; + + // Returns a Polynomial from a list of monomials. + // Fails if two monomials have the same exponent. + static FailureOr fromMonomials(ArrayRef monomials); + + /// Returns a polynomial with coefficients given by `coeffs`. The value + /// coeffs[i] is converted to a monomial with exponent i. + static Polynomial fromCoefficients(ArrayRef coeffs); + + explicit operator bool() const { return !terms.empty(); } + bool operator==(const Polynomial &other) const { + return other.terms == terms; + } + bool operator!=(const Polynomial &other) const { + return !(other.terms == terms); + } + + // Prints polynomial to 'os'. + void print(raw_ostream &os) const; + void print(raw_ostream &os, ::llvm::StringRef separator, + ::llvm::StringRef exponentiation) const; + void dump() const; + + // Prints polynomial so that it can be used as a valid identifier + std::string toIdentifier() const; + + unsigned getDegree() const; + + friend ::llvm::hash_code hash_value(const Polynomial &arg); + +private: + // The monomial terms for this polynomial. + SmallVector terms; +}; + +// Make Polynomial hashable. +inline ::llvm::hash_code hash_value(const Polynomial &arg) { + return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end()); +} + +inline ::llvm::hash_code hash_value(const Monomial &arg) { + return llvm::hash_combine(::llvm::hash_value(arg.coefficient), + ::llvm::hash_value(arg.exponent)); +} + +inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) { + polynomial.print(os); + return os; +} + +} // namespace polynomial +} // namespace mlir + +#endif // MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_ diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td new file mode 100644 index 0000000000000..5d8da8399b01b --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -0,0 +1,153 @@ +//===- PolynomialOps.td - Polynomial dialect ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef POLYNOMIAL_OPS +#define POLYNOMIAL_OPS + +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Polynomial_Dialect : Dialect { + let name = "polynomial"; + let cppNamespace = "::mlir::polynomial"; + let description = [{ + The Polynomial dialect defines single-variable polynomial types and + operations. + + The simplest use of `polynomial` is to represent mathematical operations in + a polynomial ring `R[x]`, where `R` is another MLIR type like `i32`. + + More generally, this dialect supports representing polynomial operations in a + quotient ring `R[X]/(f(x))` for some statically fixed polynomial `f(x)`. + Two polyomials `p(x), q(x)` are considered equal in this ring if they have the + same remainder when dividing by `f(x)`. When a modulus is given, ring operations + are performed with reductions modulo `f(x)` and relative to the coefficient ring + `R`. + + Examples: + + ```mlir + // A constant polynomial in a ring with i32 coefficients and no polynomial modulus + #ring = #polynomial.ring + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + + // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1) + #modulus = #polynomial.polynomial<1 + x**1024> + #ring = #polynomial.ring + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + + // A constant polynomial in a ring with i32 coefficients, with a polynomial + // modulus of (x^1024 + 1) and a coefficient modulus of 17. + #modulus = #polynomial.polynomial<1 + x**1024> + #ring = #polynomial.ring + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + ``` + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; +} + +class Polynomial_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> { + let summary = "An attribute containing a single-variable polynomial."; + let description = [{ + #poly = #polynomial.poly + }]; + let parameters = (ins "Polynomial":$polynomial); + let hasCustomAssemblyFormat = 1; +} + +def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { + let summary = "An attribute specifying a polynomial ring."; + let description = [{ + A ring describes the domain in which polynomial arithmetic occurs. The ring + attribute in `polynomial` represents the more specific case of polynomials + with a single indeterminate; whose coefficients can be represented by + another MLIR type (`coefficientType`); and, if the coefficient type is + integral, whose coefficients are taken modulo some statically known modulus + (`coefficientModulus`). + + Additionally, a polynomial ring can specify an _ideal_, which converts + polynomial arithmetic to the analogue of modular integer arithmetic, where + each polynomial is represented as its remainder when dividing by the + modulus. For single-variable polynomials, an "ideal" is always specificed + via a single polynomial, which we call `polynomialModulus`. + + An expressive example is polynomials with i32 coefficients, whose + coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of + `x**1024 - 1`. + + ```mlir + #poly_mod = #polynomial.polynomial<-1 + x**1024> + #ring = #polynomial.ring + + %0 = ... : polynomial.polynomial<#ring> + ``` + + In this case, the value of a polynomial is always "converted" to a + canonical form by applying repeated reductions by setting `x**1024 = 1` + and simplifying. + + The coefficient and polynomial modulus parameters are optional, and the + coefficient modulus is only allowed if the coefficient type is integral. + }]; + + let parameters = (ins + "Type": $coefficientType, + OptionalParameter<"IntegerAttr">: $coefficientModulus, + OptionalParameter<"PolynomialAttr">: $polynomialModulus + ); + + let hasCustomAssemblyFormat = 1; +} + +class Polynomial_Type + : TypeDef { + let mnemonic = typeMnemonic; +} + +def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> { + let summary = "An element of a polynomial ring."; + + let description = [{ + A type for polynomials in a polynomial quotient ring. + }]; + + let parameters = (ins Polynomial_RingAttr:$ring); + let assemblyFormat = "`<` $ring `>`"; +} + +class Polynomial_Op traits = []> : + Op; + +class Polynomial_UnaryOp traits = []> : + Polynomial_Op { + let arguments = (ins Polynomial_PolynomialType:$operand); + let results = (outs Polynomial_PolynomialType:$result); + + let assemblyFormat = "$operand attr-dict `:` qualified(type($result))"; +} + +class Polynomial_BinaryOp traits = []> : + Polynomial_Op { + let arguments = (ins Polynomial_PolynomialType:$lhs, Polynomial_PolynomialType:$rhs); + let results = (outs Polynomial_PolynomialType:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))"; +} + +#endif // POLYNOMIAL_OPS diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h new file mode 100644 index 0000000000000..b37d17bb89fb2 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h @@ -0,0 +1,17 @@ +//===- PolynomialAttributes.h - polynomial dialect attributes ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_ +#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_ + +#include "Polynomial.h" +#include "PolynomialDialect.h" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h.inc" + +#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_ diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h new file mode 100644 index 0000000000000..7b7acebe7a93b --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h @@ -0,0 +1,19 @@ +//===- PolynomialDialect.h - The Polynomial dialect -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_ +#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" + +// Generated headers (block clang-format from messing up order) +#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h.inc" + +#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h new file mode 100644 index 0000000000000..bacaad81ce8e5 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h @@ -0,0 +1,21 @@ +//===- PolynomialOps.h - Ops for the Polynomial dialect ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_ +#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_ + +#include "PolynomialDialect.h" +#include "PolynomialTypes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Polynomial/IR/Polynomial.h.inc" + +#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_ diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h new file mode 100644 index 0000000000000..2fc6877452547 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h @@ -0,0 +1,17 @@ +//===- PolynomialTypes.h - Types for the Polynomial dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_ +#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_ + +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h.inc" + +#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index c558dc53cc7fa..c4d788cf8ed31 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -61,6 +61,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" @@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { omp::OpenMPDialect, pdl::PDLDialect, pdl_interp::PDLInterpDialect, + polynomial::PolynomialDialect, quant::QuantizationDialect, ROCDL::ROCDLDialect, scf::SCFDialect, diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index b1ba5a3bc8817..a324ce7f9b19f 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon) add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) +add_subdirectory(Polynomial) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/lib/Dialect/Polynomial/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt new file mode 100644 index 0000000000000..7f5b3255d5d90 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRPolynomialDialect + Polynomial.cpp + PolynomialAttributes.cpp + PolynomialDialect.cpp + PolynomialOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polynomial + + DEPENDS + MLIRPolynomialIncGen + MLIRPolynomialAttributesIncGen + MLIRBuiltinAttributesIncGen + + LINK_LIBS PUBLIC + MLIRSupport + MLIRDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp new file mode 100644 index 0000000000000..5916ffba78e24 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp @@ -0,0 +1,96 @@ +//===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace polynomial { + +FailureOr Polynomial::fromMonomials(ArrayRef monomials) { + // A polynomial's terms are canonically stored in order of increasing degree. + auto monomialsCopy = llvm::SmallVector(monomials); + std::sort(monomialsCopy.begin(), monomialsCopy.end()); + + // Ensure non-unique exponents are not present. Since we sorted the list by + // exponent, a linear scan of adjancent monomials suffices. + if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(), + [](const Monomial &lhs, const Monomial &rhs) { + return lhs.exponent == rhs.exponent; + }) != monomialsCopy.end()) { + return failure(); + } + + return Polynomial(monomialsCopy); +} + +Polynomial Polynomial::fromCoefficients(ArrayRef coeffs) { + llvm::SmallVector monomials; + auto size = coeffs.size(); + monomials.reserve(size); + for (size_t i = 0; i < size; i++) { + monomials.emplace_back(coeffs[i], i); + } + auto result = Polynomial::fromMonomials(monomials); + // Construction guarantees unique exponents, so the failure mode of + // fromMonomials can be bypassed. + assert(succeeded(result)); + return result.value(); +} + +void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator, + ::llvm::StringRef exponentiation) const { + bool first = true; + for (const Monomial &term : terms) { + if (first) { + first = false; + } else { + os << separator; + } + std::string coeffToPrint; + if (term.coefficient == 1 && term.exponent.uge(1)) { + coeffToPrint = ""; + } else { + llvm::SmallString<16> coeffString; + term.coefficient.toStringSigned(coeffString); + coeffToPrint = coeffString.str(); + } + + if (term.exponent == 0) { + os << coeffToPrint; + } else if (term.exponent == 1) { + os << coeffToPrint << "x"; + } else { + llvm::SmallString<16> expString; + term.exponent.toStringSigned(expString); + os << coeffToPrint << "x" << exponentiation << expString; + } + } +} + +void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); } + +std::string Polynomial::toIdentifier() const { + std::string result; + llvm::raw_string_ostream os(result); + print(os, "_", ""); + return os.str(); +} + +unsigned Polynomial::getDegree() const { + return terms.back().exponent.getZExtValue(); +} + +} // namespace polynomial +} // namespace mlir diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp new file mode 100644 index 0000000000000..ee09c73bb3c4a --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -0,0 +1,213 @@ +//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" + +namespace mlir { +namespace polynomial { + +void PolynomialAttr::print(AsmPrinter &p) const { + p << '<'; + p << getPolynomial(); + p << '>'; +} + +/// Try to parse a monomial. If successful, populate the fields of the outparam +/// `monomial` with the results, and the `variable` outparam with the parsed +/// variable name. Sets shouldParseMore to true if the monomial is followed by +/// a '+'. +ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, + llvm::StringRef &variable, bool &isConstantTerm, + bool &shouldParseMore) { + APInt parsedCoeff(apintBitWidth, 1); + auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff); + monomial.coefficient = parsedCoeff; + + isConstantTerm = false; + shouldParseMore = false; + + // A + indicates it's a constant term with more to go, as in `1 + x`. + if (succeeded(parser.parseOptionalPlus())) { + // If no coefficient was parsed, and there's a +, then it's effectively + // parsing an empty string. + if (!parsedCoeffResult.has_value()) { + return failure(); + } + monomial.exponent = APInt(apintBitWidth, 0); + isConstantTerm = true; + shouldParseMore = true; + return success(); + } + + // A monomial can be a trailing constant term, as in `x + 1`. + if (failed(parser.parseOptionalKeyword(&variable))) { + // If neither a coefficient nor a variable was found, then it's effectively + // parsing an empty string. + if (!parsedCoeffResult.has_value()) { + return failure(); + } + + monomial.exponent = APInt(apintBitWidth, 0); + isConstantTerm = true; + return success(); + } + + // Parse exponentiation symbol as `**`. We can't use caret because it's + // reserved for basic block identifiers If no star is present, it's treated + // as a polynomial with exponent 1. + if (succeeded(parser.parseOptionalStar())) { + // If there's one * there must be two. + if (failed(parser.parseStar())) { + return failure(); + } + + // If there's a **, then the integer exponent is required. + APInt parsedExponent(apintBitWidth, 0); + if (failed(parser.parseInteger(parsedExponent))) { + parser.emitError(parser.getCurrentLocation(), + "found invalid integer exponent"); + return failure(); + } + + monomial.exponent = parsedExponent; + } else { + monomial.exponent = APInt(apintBitWidth, 1); + } + + if (succeeded(parser.parseOptionalPlus())) { + shouldParseMore = true; + } + return success(); +} + +Attribute PolynomialAttr::parse(AsmParser &parser, Type type) { + if (failed(parser.parseLess())) + return {}; + + llvm::SmallVector monomials; + llvm::StringSet<> variables; + + while (true) { + Monomial parsedMonomial; + llvm::StringRef parsedVariableRef; + bool isConstantTerm; + bool shouldParseMore; + if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef, + isConstantTerm, shouldParseMore))) { + parser.emitError(parser.getCurrentLocation(), "expected a monomial"); + return {}; + } + + if (!isConstantTerm) { + std::string parsedVariable = parsedVariableRef.str(); + variables.insert(parsedVariable); + } + monomials.push_back(parsedMonomial); + + if (shouldParseMore) + continue; + + if (succeeded(parser.parseOptionalGreater())) { + break; + } + parser.emitError( + parser.getCurrentLocation(), + "expected + and more monomials, or > to end polynomial attribute"); + return {}; + } + + if (variables.size() > 1) { + std::string vars = llvm::join(variables.keys(), ", "); + parser.emitError( + parser.getCurrentLocation(), + "polynomials must have one indeterminate, but there were multiple: " + + vars); + } + + auto result = Polynomial::fromMonomials(monomials); + if (failed(result)) { + parser.emitError(parser.getCurrentLocation()) + << "parsed polynomial must have unique exponents among monomials"; + return {}; + } + return PolynomialAttr::get(parser.getContext(), result.value()); +} + +void RingAttr::print(AsmPrinter &p) const { + p << "#polynomial.ring(); + if (!iType) { + parser.emitError(parser.getCurrentLocation(), + "coefficientType must specify an integer type"); + return {}; + } + APInt coefficientModulus(iType.getWidth(), 0); + auto result = parser.parseInteger(coefficientModulus); + if (failed(result)) { + parser.emitError(parser.getCurrentLocation(), + "invalid coefficient modulus"); + return {}; + } + coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus); + + if (failed(parser.parseComma())) + return {}; + } + + PolynomialAttr polyAttr = nullptr; + if (succeeded(parser.parseKeyword("polynomialModulus"))) { + if (failed(parser.parseEqual())) + return {}; + + PolynomialAttr attr; + if (failed(parser.parseAttribute(attr))) + return {}; + polyAttr = attr; + } + + if (failed(parser.parseGreater())) + return {}; + + return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr, + polyAttr); +} + +} // namespace polynomial +} // namespace mlir diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp new file mode 100644 index 0000000000000..a672a59b8a465 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp @@ -0,0 +1,41 @@ +//===- PolynomialDialect.cpp - Polynomial dialect ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::polynomial; + +#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc" +#define GET_OP_CLASSES +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" + +void PolynomialDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp new file mode 100644 index 0000000000000..96c59a28b8fdc --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -0,0 +1,15 @@ +//===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +using namespace mlir; +using namespace mlir::polynomial; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" diff --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir new file mode 100644 index 0000000000000..3973ae3944335 --- /dev/null +++ b/mlir/test/Dialect/Polynomial/attributes.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics + +#my_poly = #polynomial.polynomial +// expected-error@below {{polynomials must have one indeterminate, but there were multiple: x, y}} +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{expected integer value}} +// expected-error@below {{expected a monomial}} +// expected-error@below {{found invalid integer exponent}} +#my_poly = #polynomial.polynomial<5 + x**f> +#ring1 = #polynomial.ring + +// ----- + +#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2> +// expected-error@below {{parsed polynomial must have unique exponents among monomials}} +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{expected + and more monomials, or > to end polynomial attribute}} +#my_poly = #polynomial.polynomial<5 + x**2 7> +#ring1 = #polynomial.ring + +// ----- + +// expected-error@below {{expected a monomial}} +#my_poly = #polynomial.polynomial<5 + x**2 +> +#ring1 = #polynomial.ring + + +// ----- + +#my_poly = #polynomial.polynomial<5 + x**2> +// expected-error@below {{coefficientType must specify an integer type}} +#ring1 = #polynomial.ring + +// ----- + +#my_poly = #polynomial.polynomial<5 + x**2> +// expected-error@below {{expected integer value}} +// expected-error@below {{invalid coefficient modulus}} +#ring1 = #polynomial.ring diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir new file mode 100644 index 0000000000000..64b74d9d36bb1 --- /dev/null +++ b/mlir/test/Dialect/Polynomial/types.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: func @test_types +// CHECK-SAME: !polynomial.polynomial< +// CHECK-SAME: #polynomial.ring< +// CHECK-SAME: coefficientType=i32, +// CHECK-SAME: coefficientModulus=2837465 : i32, +// CHECK-SAME: polynomialModulus=#polynomial.polynomial<1 + x**1024>>> +#my_poly = #polynomial.polynomial<1 + x**1024> +#ring1 = #polynomial.ring +!ty = !polynomial.polynomial<#ring1> +func.func @test_types(%0: !ty) -> !ty { + return %0 : !ty +} + + +// CHECK-LABEL: func @test_non_x_variable_64_bit +// CHECK-SAME: !polynomial.polynomial< +// CHECK-SAME: #polynomial.ring< +// CHECK-SAME: coefficientType=i64, +// CHECK-SAME: coefficientModulus=2837465 : i64, +// CHECK-SAME: polynomialModulus=#polynomial.polynomial<2 + 4x + x**3>>> +#my_poly_2 = #polynomial.polynomial +#ring2 = #polynomial.ring +!ty2 = !polynomial.polynomial<#ring2> +func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 { + return %0 : !ty2 +} + + +// CHECK-LABEL: func @test_linear_poly +// CHECK-SAME: !polynomial.polynomial< +// CHECK-SAME: #polynomial.ring< +// CHECK-SAME: coefficientType=i32, +// CHECK-SAME: coefficientModulus=12 : i32, +// CHECK-SAME: polynomialModulus=#polynomial.polynomial<4x>> +#my_poly_3 = #polynomial.polynomial<4x> +#ring3 = #polynomial.ring +!ty3 = !polynomial.polynomial<#ring3> +func.func @test_linear_poly(%0: !ty3) -> !ty3 { + return %0 : !ty3 +}