Skip to content

Commit 0f861d3

Browse files
committed
polynomial: add dialect shell, attrs, and types
1 parent 754b93e commit 0f861d3

22 files changed

+946
-0
lines changed

mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_subdirectory(OpenACCMPCommon)
2727
add_subdirectory(OpenMP)
2828
add_subdirectory(PDL)
2929
add_subdirectory(PDLInterp)
30+
add_subdirectory(Polynomial)
3031
add_subdirectory(Quant)
3132
add_subdirectory(SCF)
3233
add_subdirectory(Shape)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
set(LLVM_TARGET_DEFINITIONS PolynomialOps.td)
2+
mlir_tablegen(PolynomialDialect.cpp.inc -gen-dialect-defs -dialect=polynomial)
3+
mlir_tablegen(PolynomialDialect.h.inc -gen-dialect-decls -dialect=polynomial)
4+
add_public_tablegen_target(MLIRPolynomialDialectIncGen)
5+
6+
mlir_tablegen(PolynomialAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=polynomial)
7+
mlir_tablegen(PolynomialAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=polynomial)
8+
mlir_tablegen(PolynomialOps.cpp.inc -gen-op-defs)
9+
mlir_tablegen(PolynomialOps.h.inc -gen-op-decls)
10+
mlir_tablegen(PolynomialTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=polynomial)
11+
mlir_tablegen(PolynomialTypes.h.inc -gen-typedef-decls -typedefs-dialect=polynomial)
12+
add_public_tablegen_target(MLIRPolynomialAttributesIncGen)
13+
add_public_tablegen_target(MLIRPolynomialOpsIncGen)
14+
add_public_tablegen_target(MLIRPolynomialTypesIncGen)
15+
add_dependencies(mlir-headers MLIRPolynomialOpsIncGen)
16+
17+
add_mlir_doc(PolynoialOps PolynoialOps Dialects/ -gen-dialect-doc -dialect polynomial)
18+
add_mlir_doc(PolynomialAttributes PolynomialAttributes Dialects/ -gen-attrdef-doc)
19+
add_mlir_doc(PolynomialTypes PolynomialTypes Dialects/ -gen-typedef-doc)
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
//===- Polynomial.h - A storage class for polynomial types --------*- C++-*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
10+
#define INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
11+
12+
#include <utility>
13+
14+
#include "mlir/Support/LLVM.h"
15+
#include "llvm/ADT/APInt.h"
16+
#include "llvm/ADT/DenseMapInfo.h"
17+
#include "llvm/ADT/Hashing.h"
18+
19+
namespace mlir {
20+
21+
class MLIRContext;
22+
23+
namespace polynomial {
24+
25+
// This restricts statically defined polynomials to have at most 64-bit
26+
// coefficients. This may be relaxed in the future, but it seems unlikely one
27+
// would want to specify 128-bit polynomials statically in the source code.
28+
constexpr unsigned apintBitWidth = 64;
29+
30+
namespace detail {
31+
struct PolynomialStorage;
32+
} // namespace detail
33+
34+
class Monomial {
35+
public:
36+
Monomial(int64_t coeff, uint64_t expo)
37+
: coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
38+
39+
Monomial(APInt coeff, APInt expo)
40+
: coefficient(std::move(coeff)), exponent(std::move(expo)) {}
41+
42+
Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
43+
44+
bool operator==(const Monomial &other) const {
45+
return other.coefficient == coefficient && other.exponent == exponent;
46+
}
47+
bool operator!=(const Monomial &other) const {
48+
return other.coefficient != coefficient || other.exponent != exponent;
49+
}
50+
51+
/// Monomials are ordered by exponent.
52+
bool operator<(const Monomial &other) const {
53+
return (exponent.ult(other.exponent));
54+
}
55+
56+
// Prints polynomial to 'os'.
57+
void print(raw_ostream &os) const;
58+
59+
friend ::llvm::hash_code hash_value(Monomial arg);
60+
61+
public:
62+
APInt coefficient;
63+
64+
// Always unsigned
65+
APInt exponent;
66+
};
67+
68+
/// A single-variable polynomial with integer coefficients. Polynomials are
69+
/// immutable and uniqued.
70+
///
71+
/// Eg: x^1024 + x + 1
72+
///
73+
/// The symbols used as the polynomial's indeterminate don't matter, so long as
74+
/// it is used consistently throughout the polynomial.
75+
class Polynomial {
76+
public:
77+
using ImplType = detail::PolynomialStorage;
78+
79+
constexpr Polynomial() = default;
80+
explicit Polynomial(ImplType *terms) : terms(terms) {}
81+
82+
static Polynomial fromMonomials(ArrayRef<Monomial> monomials,
83+
MLIRContext *context);
84+
/// Returns a polynomial with coefficients given by `coeffs`
85+
static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs,
86+
MLIRContext *context);
87+
88+
MLIRContext *getContext() const;
89+
90+
explicit operator bool() const { return terms != nullptr; }
91+
bool operator==(Polynomial other) const { return other.terms == terms; }
92+
bool operator!=(Polynomial other) const { return !(other.terms == terms); }
93+
94+
// Prints polynomial to 'os'.
95+
void print(raw_ostream &os) const;
96+
void print(raw_ostream &os, const std::string &separator,
97+
const std::string &exponentiation) const;
98+
void dump() const;
99+
100+
// Prints polynomial so that it can be used as a valid identifier
101+
std::string toIdentifier() const;
102+
103+
// A polynomial's terms are canonically stored in order of increasing degree.
104+
ArrayRef<Monomial> getTerms() const;
105+
106+
unsigned getDegree() const;
107+
108+
friend ::llvm::hash_code hash_value(Polynomial arg);
109+
110+
private:
111+
ImplType *terms{nullptr};
112+
};
113+
114+
// Make Polynomial hashable.
115+
inline ::llvm::hash_code hash_value(Polynomial arg) {
116+
return ::llvm::hash_value(arg.terms);
117+
}
118+
119+
inline ::llvm::hash_code hash_value(Monomial arg) {
120+
return ::llvm::hash_value(arg.coefficient) ^ ::llvm::hash_value(arg.exponent);
121+
}
122+
123+
inline raw_ostream &operator<<(raw_ostream &os, Polynomial polynomial) {
124+
polynomial.print(os);
125+
return os;
126+
}
127+
128+
} // namespace polynomial
129+
} // namespace mlir
130+
131+
namespace llvm {
132+
133+
// Polynomials hash just like pointers
134+
template <>
135+
struct DenseMapInfo<mlir::polynomial::Polynomial> {
136+
static mlir::polynomial::Polynomial getEmptyKey() {
137+
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
138+
return mlir::polynomial::Polynomial(
139+
static_cast<mlir::polynomial::Polynomial::ImplType *>(pointer));
140+
}
141+
static mlir::polynomial::Polynomial getTombstoneKey() {
142+
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
143+
return mlir::polynomial::Polynomial(
144+
static_cast<mlir::polynomial::Polynomial::ImplType *>(pointer));
145+
}
146+
static unsigned getHashValue(mlir::polynomial::Polynomial val) {
147+
return mlir::polynomial::hash_value(val);
148+
}
149+
static bool isEqual(mlir::polynomial::Polynomial lhs,
150+
mlir::polynomial::Polynomial rhs) {
151+
return lhs == rhs;
152+
}
153+
};
154+
155+
} // namespace llvm
156+
157+
#endif // INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- PolynomialAttributes.h - Attributes for the Polynomial dialect -*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
#ifndef INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
10+
#define INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
11+
12+
#include "Polynomial.h"
13+
#include "PolynomialDialect.h"
14+
15+
#define GET_ATTRDEF_CLASSES
16+
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h.inc"
17+
18+
#endif // INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//===- PolynomialAttributes.td - Attribute definitions for the polynomial dialect ------*- tablegen -*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef POLYNOMIAL_ATTRIBUTES
9+
#define POLYNOMIAL_ATTRIBUTES
10+
11+
include "PolynomialDialect.td"
12+
include "mlir/IR/BuiltinAttributes.td"
13+
include "mlir/IR/OpBase.td"
14+
15+
class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
16+
: AttrDef<Polynomial_Dialect, name, traits> {
17+
let mnemonic = attrMnemonic;
18+
}
19+
20+
def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
21+
let summary = "An attribute containing a single-variable polynomial.";
22+
let description = [{
23+
#poly = #polynomial.poly<x**1024 + 1>
24+
}];
25+
26+
let parameters = (ins "Polynomial":$polynomial);
27+
28+
let builders = [
29+
AttrBuilderWithInferredContext<(ins "Polynomial":$polynomial), [{
30+
return $_get(polynomial.getContext(), polynomial);
31+
}]>
32+
];
33+
34+
let skipDefaultBuilders = 1;
35+
let hasCustomAssemblyFormat = 1;
36+
}
37+
38+
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
39+
let summary = "An attribute specifying a polynomial ring.";
40+
let description = [{
41+
A ring describes the domain in which polynomial arithmetic occurs. The ring
42+
attribute in `polynomial` represents the more specific case of polynomials
43+
with a single indeterminate; whose coefficients can be represented by
44+
another MLIR type (`ctype`); and, if the coefficient type is integral,
45+
whose coefficients are taken modulo some statically known modulus (`cmod`).
46+
47+
Additionally, a polynomial ring can specify an _ideal_, which converts
48+
polynomial arithmetic to the analogue of modular integer arithmetic, where
49+
each polynomial is represented as its remainder when dividing by the
50+
modulus. For single-variable polynomials, an "ideal" is always specificed
51+
via a single polynomial, which we call `polynomialModulus`.
52+
53+
An expressive example is polynomials with i32 coefficients, whose
54+
coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
55+
`x**1024 - 1`.
56+
57+
```
58+
#poly_mod = #polynomial.polynomial<-1 + x**1024>
59+
#ring = #polynomial.ring<ctype=i32, cmod=4294967291, ideal=#poly_mod>
60+
61+
%0 = ... : polynomial.polynomial<#ring>
62+
```
63+
64+
In this case, the value of a polynomial is always ``converted'' to a
65+
canonical form by applying repeated reductions by setting `x**1024 = 1`
66+
and simplifying.
67+
68+
The coefficient and polynomial modulus parameters are optional, and the
69+
coefficient modulus is only allowed if the coefficient type is integral.
70+
}];
71+
72+
let parameters = (ins
73+
Builtin_TypeAttr: $coefficientType,
74+
OptionalParameter<"std::optional<IntegerAttr>">: $coefficientModulus,
75+
OptionalParameter<"std::optional<PolynomialAttr>">: $polynomialModulus
76+
);
77+
78+
let hasCustomAssemblyFormat = 1;
79+
}
80+
81+
#endif // POLYNOMIAL_ATTRIBUTES
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===- PolynomialDialect.h - The Polynomial dialect -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_
9+
#define INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_
10+
11+
#include "mlir/IR/Builders.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/Dialect.h"
14+
#include "mlir/IR/DialectImplementation.h"
15+
16+
// Generated headers (block clang-format from messing up order)
17+
#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h.inc"
18+
19+
#endif // INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- PolynomialDialect.td - Dialect definition for the polynomial dialect ------*- tablegen -*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef POLYNOMIAL_DIALECT
9+
#define POLYNOMIAL_DIALECT
10+
11+
include "mlir/IR/OpBase.td"
12+
13+
def Polynomial_Dialect : Dialect {
14+
let name = "polynomial";
15+
let cppNamespace = "::mlir::polynomial";
16+
let description = [{
17+
The Polynomial dialect defines single-variable polynomial types and
18+
operations.
19+
20+
The simplest use of `polynomial` is to represent mathematical operations in
21+
a polynomial ring `R[x]`, where `R` is another MLIR type like `i32`.
22+
23+
More generally, this dialect supports representing polynomial operations in a
24+
quotient ring `R[X]/(f(x))` for some statically fixed polynomial `f(x)`.
25+
Two polyomials `p(x), q(x)` are considered equal in this ring if they have the
26+
same remainder when dividing by `f(x)`. When a modulus is given, ring operations
27+
are performed with reductions modulo `f(x)` and relative to the coefficient ring
28+
`R`.
29+
30+
Examples:
31+
32+
```mlir
33+
// A constant polynomial in a ring with i32 coefficients and no polynomial modulus
34+
#ring = #polynomial.ring<ctype=i32>
35+
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
36+
37+
// A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
38+
#modulus = #polynomial.polynomial<1 + x**1024>
39+
#ring = #polynomial.ring<ctype=i32, ideal=#modulus>
40+
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
41+
42+
// A constant polynomial in a ring with i32 coefficients, with a polynomial
43+
// modulus of (x^1024 + 1) and a coefficient modulus of 17.
44+
#modulus = #polynomial.polynomial<1 + x**1024>
45+
#ring = #polynomial.ring<ctype=i32, cmod=17, ideal=#modulus>
46+
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
47+
```
48+
}];
49+
50+
let useDefaultTypePrinterParser = 1;
51+
let useDefaultAttributePrinterParser = 1;
52+
}
53+
54+
#endif // POLYNOMIAL_DIALECT
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- PolynomialOps.h - Ops for the Polynomial dialect -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_
9+
#define INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_
10+
11+
#include "PolynomialDialect.h"
12+
#include "PolynomialTypes.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/Dialect.h"
16+
#include "mlir/Interfaces/InferTypeOpInterface.h"
17+
18+
#define GET_OP_CLASSES
19+
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h.inc"
20+
21+
#endif // INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_

0 commit comments

Comments
 (0)