Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mapped_domain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ mlir::LogicalResult MappedDomain::ResolveUnification(
mlir::Operation *dim_op = dimension.value.defining_op().GetDuplicatedOp();
if (isa<SairPlaceholderOp>(dim_op)) return mlir::success();

if (constraint.isa<MappingNoneExpr, MappingUnknownExpr>()) {
if (llvm::isa<MappingNoneExpr, MappingUnknownExpr>(constraint)) {
// If the dimension is new, extend the domain.
constraint = MappingDimExpr::get(domain_.size(), context());
assert(dimension.mapping.IsSurjective());
Expand Down
82 changes: 45 additions & 37 deletions sair_attributes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ llvm::SmallBitVector MappingExpr::DependencyMask(int domain_size) const {
bool MappingExpr::HasNoneExprs() const {
bool has_none_exprs = false;
Walk([&](MappingExpr sub_expr) {
has_none_exprs |= sub_expr.isa<MappingNoneExpr>();
has_none_exprs |= llvm::isa<MappingNoneExpr>(sub_expr);
});
return has_none_exprs;
}

bool MappingExpr::HasUnknownExprs() const {
bool has_unknown_exprs = false;
Walk([&](MappingExpr sub_expr) {
has_unknown_exprs |= sub_expr.isa<MappingUnknownExpr>();
has_unknown_exprs |= llvm::isa<MappingUnknownExpr>(sub_expr);
});
return has_unknown_exprs;
}
Expand All @@ -109,10 +109,10 @@ int MappingExpr::MinDomainSize() const {
// expression is `?` or `none`. Returns `nullptr` if unification fails.
static MappingExpr ResolveNoneAndUnknownUnification(MappingExpr lhs,
MappingExpr rhs) {
if (lhs.isa<MappingNoneExpr>()) return rhs;
if (rhs.isa<MappingNoneExpr>()) return lhs;
if (lhs.isa<MappingUnknownExpr>()) return rhs;
if (rhs.isa<MappingUnknownExpr>()) return lhs;
if (llvm::isa<MappingNoneExpr>(lhs)) return rhs;
if (llvm::isa<MappingNoneExpr>(rhs)) return lhs;
if (llvm::isa<MappingUnknownExpr>(lhs)) return rhs;
if (llvm::isa<MappingUnknownExpr>(rhs)) return lhs;
return MappingExpr();
}

Expand Down Expand Up @@ -383,7 +383,7 @@ mlir::LogicalResult MappingStripeExpr::SetInverse(
MappingExpr MappingStripeExpr::FindInInverse(
llvm::ArrayRef<MappingExpr> inverse) const {
auto operand_inverse = operand().FindInInverse(inverse);
if (operand_inverse.isa<MappingUnknownExpr, MappingNoneExpr>()) {
if (llvm::isa<MappingUnknownExpr, MappingNoneExpr>(operand_inverse)) {
return operand_inverse;
}
auto unstripe_expr = llvm::cast<MappingUnStripeExpr>(operand_inverse);
Expand Down Expand Up @@ -545,7 +545,7 @@ MappingExpr MappingUnStripeExpr::Unify(

// If the last operand is `none` or `?`, we can replace it by an arbitrary
// number of operands.
if (min_operands.back().isa<MappingNoneExpr, MappingUnknownExpr>()) {
if (llvm::isa<MappingNoneExpr, MappingUnknownExpr>(min_operands.back())) {
min_operands = min_operands.drop_back();
min_factors = min_factors.drop_back();
}
Expand All @@ -568,7 +568,7 @@ MappingExpr MappingUnStripeExpr::FindInInverse(
MappingExpr operand_inverse;
for (int i = 0, e = operands().size(); i < e; ++i) {
operand_inverse = operands()[i].FindInInverse(inverse);
if (operand_inverse.isa<MappingUnknownExpr, MappingNoneExpr>()) continue;
if (llvm::isa<MappingUnknownExpr, MappingNoneExpr>(operand_inverse)) continue;
return llvm::cast<MappingStripeExpr>(operand_inverse).operand();
}
// Unstripe has at least one operand.
Expand Down Expand Up @@ -797,7 +797,7 @@ MappingAttr MappingAttr::MakeSurjective() const {
new_exprs.reserve(size());
for (MappingExpr expr : Dimensions()) {
MappingExpr new_expr = expr.Map([&](MappingExpr sub_expr) -> MappingExpr {
if (!sub_expr.isa<MappingNoneExpr>()) return sub_expr;
if (!llvm::isa<MappingNoneExpr>(sub_expr)) return sub_expr;
return MappingDimExpr::get(num_dimensions++, getContext());
});
new_exprs.push_back(new_expr);
Expand All @@ -810,7 +810,7 @@ MappingAttr MappingAttr::MakeFullySpecified() const {
auto new_exprs =
llvm::to_vector<4>(llvm::map_range(Dimensions(), [&](auto expr) {
return expr.Map([&](MappingExpr sub_expr) -> MappingExpr {
return sub_expr.isa<MappingUnknownExpr>() ? none : sub_expr;
return llvm::isa<MappingUnknownExpr>(sub_expr) ? none : sub_expr;
});
}));
return MappingAttr::get(getContext(), UseDomainSize(), new_exprs);
Expand Down Expand Up @@ -946,8 +946,8 @@ MappingAttr MappingAttr::UnifyUnknownExprs(MappingAttr other) const {
for (auto [lhs, rhs] : llvm::zip(Dimensions(), other.Dimensions())) {
MappingExpr unified =
lhs.Unify(rhs, [](MappingExpr sub_lhs, MappingExpr sub_rhs) {
if (sub_lhs.isa<MappingUnknownExpr>()) return sub_rhs;
if (sub_rhs.isa<MappingUnknownExpr>()) return sub_lhs;
if (llvm::isa<MappingUnknownExpr>(sub_lhs)) return sub_rhs;
if (llvm::isa<MappingUnknownExpr>(sub_rhs)) return sub_lhs;
return MappingExpr();
});
if (unified == nullptr) return nullptr;
Expand Down Expand Up @@ -1236,7 +1236,7 @@ static DomainShapeDim StripeAccessedShape(MappingStripeExpr expr,
static DomainShapeDim UnStripeAccessedShape(MappingUnStripeExpr expr,
DomainShapeDim inner_shape,
MappingAttr inverted_mapping) {
if (inner_shape.type().isa<DynRangeType>()) return inner_shape;
if (llvm::isa<DynRangeType>(inner_shape.type())) return inner_shape;
auto type = llvm::cast<StaticRangeType>(inner_shape.type());
int new_step = type.getStep() / expr.factors().front();
return DomainShapeDim(
Expand Down Expand Up @@ -1460,10 +1460,10 @@ bool LoopAttr::classof(mlir::Attribute attr) {
if (!derived) return false;

auto name = derived.get("name");
if (!name.isa_and_nonnull<mlir::StringAttr>()) return false;
if (!llvm::isa_and_nonnull<mlir::StringAttr>(name)) return false;

auto iter = derived.get("iter");
if (!iter.isa_and_nonnull<sair::MappingExpr>()) return false;
if (!llvm::isa_and_nonnull<sair::MappingExpr>(iter)) return false;

auto unroll = derived.get("unroll");
if (!unroll) return derived.size() == 2;
Expand All @@ -1481,23 +1481,25 @@ mlir::StringAttr LoopAttr::name() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto name = derived.get("name");
assert(name && "attribute not found.");
assert(name.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::StringAttr>(name) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::StringAttr>(name);
}

MappingExpr LoopAttr::iter() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto iter = derived.get("iter");
assert(iter && "attribute not found.");
assert(iter.isa<MappingExpr>() && "incorrect Attribute type found.");
assert(llvm::isa<MappingExpr>(iter) && "incorrect Attribute type found.");
return llvm::cast<MappingExpr>(iter);
}

mlir::IntegerAttr LoopAttr::unroll() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto unroll = derived.get("unroll");
if (!unroll) return nullptr;
assert(unroll.isa<mlir::IntegerAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::IntegerAttr>(unroll) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::IntegerAttr>(unroll);
}

Expand Down Expand Up @@ -1531,19 +1533,19 @@ bool BufferAttr::classof(mlir::Attribute attr) {
int num_absent_attrs = 0;

auto space = derived.get("space");
if (!space.isa_and_nonnull<mlir::StringAttr>()) return false;
if (!llvm::isa_and_nonnull<mlir::StringAttr>(space)) return false;

auto name = derived.get("name");
if (!name) {
++num_absent_attrs;
} else if (!name.isa<mlir::StringAttr>()) {
} else if (!llvm::isa<mlir::StringAttr>(name)) {
return false;
}

auto layout = derived.get("layout");
if (!layout) {
++num_absent_attrs;
} else if (!layout.isa<NamedMappingAttr>()) {
} else if (!llvm::isa<NamedMappingAttr>(layout)) {
return false;
}

Expand All @@ -1554,23 +1556,25 @@ mlir::StringAttr BufferAttr::space() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto space = derived.get("space");
assert(space && "attribute not found.");
assert(space.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::StringAttr>(space) && "incorrect Attribute type found.");
return llvm::cast<mlir::StringAttr>(space);
}

mlir::StringAttr BufferAttr::name() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto name = derived.get("name");
if (!name) return nullptr;
assert(name.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::StringAttr>(name) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::StringAttr>(name);
}

NamedMappingAttr BufferAttr::layout() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto layout = derived.get("layout");
if (!layout) return nullptr;
assert(layout.isa<NamedMappingAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<NamedMappingAttr>(layout) &&
"incorrect Attribute type found.");
return llvm::cast<NamedMappingAttr>(layout);
}

Expand Down Expand Up @@ -1640,7 +1644,7 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
auto loop_nest_attr = llvm::dyn_cast<mlir::ArrayAttr>(loop_nest);
if (!loop_nest_attr) return false;
if (llvm::any_of(loop_nest_attr, [](mlir::Attribute attr) {
return !attr.isa_and_nonnull<LoopAttr>();
return !llvm::isa_and_nonnull<LoopAttr>(attr);
})) {
return false;
}
Expand All @@ -1649,21 +1653,21 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
auto storage = derived.get("storage");
if (!storage) {
++num_absent_attrs;
} else if (!storage.isa<mlir::ArrayAttr>()) {
} else if (!llvm::isa<mlir::ArrayAttr>(storage)) {
return false;
}

auto expansion = derived.get("expansion");
if (!expansion) {
++num_absent_attrs;
} else if (!expansion.isa<mlir::StringAttr>()) {
} else if (!llvm::isa<mlir::StringAttr>(expansion)) {
return false;
}

auto copy_of = derived.get("copy_of");
if (!copy_of) {
++num_absent_attrs;
} else if (!copy_of.isa<CopyAttr, InstanceAttr, mlir::UnitAttr>()) {
} else if (!llvm::isa<CopyAttr, InstanceAttr, mlir::UnitAttr>(copy_of)) {
return false;
}

Expand All @@ -1673,8 +1677,8 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
} else {
auto operands_attr = llvm::dyn_cast<mlir::ArrayAttr>(operands);
if (llvm::any_of(operands_attr, [](mlir::Attribute attr) {
return !attr.isa_and_nonnull<CopyAttr, InstanceAttr,
mlir::UnitAttr>();
return !llvm::isa_and_nonnull<CopyAttr, InstanceAttr, mlir::UnitAttr>(
attr);
})) {
return false;
}
Expand All @@ -1687,7 +1691,7 @@ mlir::IntegerAttr DecisionsAttr::sequence() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto sequence = derived.get("sequence");
if (!sequence) return nullptr;
assert(sequence.isa<mlir::IntegerAttr>() &&
assert(llvm::isa<mlir::IntegerAttr>(sequence) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::IntegerAttr>(sequence);
}
Expand All @@ -1696,23 +1700,25 @@ mlir::ArrayAttr DecisionsAttr::loop_nest() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto loop_nest = derived.get("loop_nest");
if (!loop_nest) return nullptr;
assert(loop_nest.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::ArrayAttr>(loop_nest) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::ArrayAttr>(loop_nest);
}

mlir::ArrayAttr DecisionsAttr::storage() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto storage = derived.get("storage");
if (!storage) return nullptr;
assert(storage.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::ArrayAttr>(storage) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::ArrayAttr>(storage);
}

mlir::StringAttr DecisionsAttr::expansion() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto expansion = derived.get("expansion");
if (!expansion) return nullptr;
assert(expansion.isa<mlir::StringAttr>() &&
assert(llvm::isa<mlir::StringAttr>(expansion) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::StringAttr>(expansion);
}
Expand All @@ -1721,15 +1727,17 @@ mlir::Attribute DecisionsAttr::copy_of() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto copy_of = derived.get("copy_of");
if (!copy_of) return nullptr;
assert(copy_of.isa<mlir::Attribute>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::Attribute>(copy_of) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::Attribute>(copy_of);
}

mlir::ArrayAttr DecisionsAttr::operands() const {
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
auto operands = derived.get("operands");
if (!operands) return nullptr;
assert(operands.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
assert(llvm::isa<mlir::ArrayAttr>(operands) &&
"incorrect Attribute type found.");
return llvm::cast<mlir::ArrayAttr>(operands);
}

Expand Down
19 changes: 8 additions & 11 deletions sair_base.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def SairEmptyDomainShapeAttr :
def SairResultDomainShapeAttr :
DerivedAttr<"DomainShapeAttr", [{
mlir::Type type = getOperation()->getResult(0).getType();
return type.cast<ShapedType>().Shape();
return llvm::cast<ShapedType>(type).Shape();
}]> {
let convertFromStorage = [{$_self}];
}
Expand Down Expand Up @@ -248,7 +248,7 @@ def SairValue : Type<CPred<"isa<ValueType>($_self)">, "value">;

// Predicate that checks the element type of a Sair value.
class SairElementTypePred<Type type>
: SubstLeaves<"$_self", "$_self.cast<ValueType>().ElementType()",
: SubstLeaves<"$_self", "llvm::cast<ValueType>($_self).ElementType()",
type.predicate>;

// Type constraint for Sair values with a specific element type.
Expand Down Expand Up @@ -420,7 +420,7 @@ def SairOpInterface : OpInterface<"SairOp"> {
"Returns lowering decisions for the given operation instance",
"DecisionsAttr", "GetDecisions", (ins "int":$instance), [{}], [{
mlir::ArrayAttr instances = *$_op.getInstances();
return instances.getValue()[instance].cast<DecisionsAttr>();
return llvm::cast<DecisionsAttr>(instances.getValue()[instance]);
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -504,24 +504,23 @@ def SairValueProducerOp : OpInterface<"ValueProducerOp"> {
"llvm::ArrayRef<mlir::Attribute>", "GetCopies", (ins "int":$result), [{}], [{
auto all_copies = $_op.getCopiesAttr();
if (all_copies == nullptr) return {};
return all_copies.getValue()[result]
.template cast<mlir::ArrayAttr>().getValue();
return llvm::cast<mlir::ArrayAttr>(all_copies.getValue()[result]).getValue();
}]>,
InterfaceMethod<
"Indicates if the operation has any copy set in its `copies` attribute`",
"bool", "HasCopies", (ins), [{}], [{
auto all_copies = $_op.getCopiesAttr();
if (all_copies == nullptr) return false;
return llvm::any_of(all_copies.getValue(), [](mlir::Attribute attr) {
return !attr.cast<mlir::ArrayAttr>().empty();
return !llvm::cast<mlir::ArrayAttr>(attr).empty();
});
}]>,
InterfaceMethod<
"Set decisions for the given copy of the given result.",
"void", "SetCopy",
(ins "int":$result, "int":$copy, "DecisionsAttr":$decisions), [{}], [{
auto all_copies = llvm::to_vector<4>(*$_op.getCopies());
auto result_copies_attr = all_copies[result].template cast<mlir::ArrayAttr>();
auto result_copies_attr = llvm::cast<mlir::ArrayAttr>(all_copies[result]);
auto result_copies = llvm::to_vector<4>(result_copies_attr.getValue());

result_copies[copy] = decisions;
Expand Down Expand Up @@ -582,10 +581,8 @@ def SairFromToMemRefOp : OpInterface<"FromToMemRefOp"> {
InterfaceMethod<"Buffer name", "llvm::StringRef", "getBufferName">,
InterfaceMethod<"Memref type", "mlir::MemRefType", "MemRefType", (ins),
[{}], [{
return $_op.MemRef()
.GetType()
.ElementType()
.template cast<mlir::MemRefType>();
return llvm::cast<mlir::MemRefType>(
$_op.MemRef().GetType().ElementType());
}]>,
InterfaceMethod<"Mapping from value domain to layout", "MappingAttr",
"Layout", (ins), [{}], [{
Expand Down
6 changes: 3 additions & 3 deletions sair_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ namespace {
// Accepts a ray stream so that it can be used from different flavors of
// printers.
void PrintMappingExpr(MappingExpr expr, llvm::raw_ostream &os) {
if (expr.isa<MappingNoneExpr>()) {
if (llvm::isa<MappingNoneExpr>(expr)) {
os << MappingNoneExpr::kAttrName;
} else if (expr.isa<MappingUnknownExpr>()) {
} else if (llvm::isa<MappingUnknownExpr>(expr)) {
os << MappingUnknownExpr::kAttrName;
} else if (auto dim_expr = llvm::dyn_cast<MappingDimExpr>(expr)) {
os << "d" << dim_expr.dimension();
Expand Down Expand Up @@ -325,7 +325,7 @@ void PrintDomainShapeDim(const DomainShapeDim &dimension,
mlir::DialectAsmPrinter &os) {
if (auto static_range = llvm::dyn_cast<StaticRangeType>(dimension.type())) {
Print(static_range, os);
} else if (dimension.type().isa<DynRangeType>()) {
} else if (llvm::isa<DynRangeType>(dimension.type())) {
os << DynRangeType::Name();
} else {
llvm_unreachable("unsupported dimension type");
Expand Down
Loading