Skip to content

Commit 6f0ff78

Browse files
committed
move duplicate exponent check to fromMonomials
1 parent b86febc commit 6f0ff78

File tree

4 files changed

+29
-20
lines changed

4 files changed

+29
-20
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
#ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
1010
#define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
1111

12-
#include <utility>
13-
12+
#include "mlir/Support/LogicalResult.h"
1413
#include "mlir/Support/LLVM.h"
1514
#include "llvm/ADT/APInt.h"
1615
#include "llvm/ADT/ArrayRef.h"
@@ -76,7 +75,9 @@ class Polynomial {
7675

7776
explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms) {};
7877

79-
static Polynomial fromMonomials(ArrayRef<Monomial> monomials);
78+
// Returns a Polynomial from a list of monomials.
79+
// Fails if two monomials have the same exponent.
80+
static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
8081

8182
/// Returns a polynomial with coefficients given by `coeffs`. The value
8283
/// coeffs[i] is converted to a monomial with exponent i.

mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1010

11-
#include "mlir/IR/MLIRContext.h"
11+
#include "mlir/Support/LogicalResult.h"
1212
#include "llvm/ADT/APInt.h"
1313
#include "llvm/ADT/SmallString.h"
1414
#include "llvm/ADT/SmallVector.h"
@@ -18,10 +18,20 @@
1818
namespace mlir {
1919
namespace polynomial {
2020

21-
Polynomial Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
21+
FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
2222
// A polynomial's terms are canonically stored in order of increasing degree.
2323
auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
2424
std::sort(monomialsCopy.begin(), monomialsCopy.end());
25+
26+
// Ensure non-unique exponents are not present. Since we sorted the list by
27+
// exponent, a linear scan of adjancent monomials suffices.
28+
if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
29+
[](const Monomial &lhs, const Monomial &rhs) {
30+
return lhs.exponent == rhs.exponent;
31+
}) != monomialsCopy.end()) {
32+
return failure();
33+
}
34+
2535
return Polynomial(monomialsCopy);
2636
}
2737

@@ -32,7 +42,11 @@ Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
3242
for (size_t i = 0; i < size; i++) {
3343
monomials.emplace_back(coeffs[i], i);
3444
}
35-
return Polynomial::fromMonomials(monomials);
45+
auto result = Polynomial::fromMonomials(monomials);
46+
// Construction guarantees unique exponents, so the failure mode of
47+
// fromMonomials can be bypassed.
48+
assert(succeeded(result));
49+
return result.value();
3650
}
3751

3852
void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,

mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
9797

9898
llvm::SmallVector<Monomial> monomials;
9999
llvm::StringSet<> variables;
100-
llvm::DenseSet<APInt> exponents;
101100

102101
while (true) {
103102
Monomial parsedMonomial;
@@ -116,16 +115,6 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
116115
}
117116
monomials.push_back(parsedMonomial);
118117

119-
if (llvm::is_contained(exponents, parsedMonomial.exponent)) {
120-
llvm::SmallString<16> coeffString;
121-
parsedMonomial.exponent.toStringSigned(coeffString);
122-
parser.emitError(parser.getCurrentLocation())
123-
<< "at most one monomial may have exponent " << coeffString
124-
<< ", but found multiple";
125-
return {};
126-
}
127-
exponents.insert(parsedMonomial.exponent);
128-
129118
if (shouldParseMore)
130119
continue;
131120

@@ -146,8 +135,13 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
146135
vars);
147136
}
148137

149-
Polynomial poly = Polynomial::fromMonomials(monomials);
150-
return PolynomialAttr::get(parser.getContext(), poly);
138+
auto result = Polynomial::fromMonomials(monomials);
139+
if (failed(result)) {
140+
parser.emitError(parser.getCurrentLocation())
141+
<< "parsed polynomial must have unique exponents among monomials";
142+
return {};
143+
}
144+
return PolynomialAttr::get(parser.getContext(), result.value());
151145
}
152146

153147
void RingAttr::print(AsmPrinter &p) const {

mlir/test/Dialect/Polynomial/attributes.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
// -----
1616

17-
// expected-error@below {{at most one monomial may have exponent 2, but found multiple}}
1817
#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2>
18+
// expected-error@below {{parsed polynomial must have unique exponents among monomials}}
1919
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
2020

2121
// -----

0 commit comments

Comments
 (0)