Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,52 @@ def AnyContextAttr : RTGAttrDef<"AnyContext", [
let assemblyFormat = "";
}

// The accessor type is a pointer instead of a reference because references
// cannot be reassigned (in the attribute 'construct' function).
class DenseSetParameter<string setOf, string desc = "">
: AttrOrTypeParameter<"const ::llvm::DenseSet<" # setOf # "> *", desc> {
let allocator = [{
$_dst = new ($_allocator.allocate<DenseSet<}] # setOf # [{>>())
DenseSet<}] # setOf # [{>($_self->begin(), $_self->end());
}];
let cppStorageType = "::llvm::DenseSet<" # setOf # ">";
}

def SetAttr : RTGAttrDef<"Set", [
DeclareAttrInterfaceMethods<TypedAttrInterface>,
]> {
let summary = "an unordered set of elements";

let parameters = (ins
AttributeSelfTypeParameter<"", "rtg::SetType">:$type,
DenseSetParameter<"TypedAttr", "elements">:$elements);

let builders = [
AttrBuilderWithInferredContext<
(ins "rtg::SetType":$type,
"const ::llvm::DenseSet<::mlir::TypedAttr> *":$elements), [{
return $_get(type.getContext(), type, elements);
}]>
];

let mnemonic = "set";
let hasCustomAssemblyFormat = true;

let genVerifyDecl = true;
}

def TupleAttr : RTGAttrDef<"Tuple", [
DeclareAttrInterfaceMethods<TypedAttrInterface>,
]> {
let summary = "a tuple";

let parameters =
(ins ArrayRefParameter<"::mlir::TypedAttr", "elements">:$elements);

let mnemonic = "tuple";
let assemblyFormat = "`<` $elements `>`";
}

//===----------------------------------------------------------------------===//
// Attributes for ISA targets
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def SetCreateOp : RTGOp<"set_create", [Pure, SameTypeOperands]> {

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let hasFolder = 1;
}

def SetSelectRandomOp : RTGOp<"set_select_random", [
Expand Down Expand Up @@ -332,6 +333,8 @@ def SetDifferenceOp : RTGOp<"set_difference", [
let assemblyFormat = [{
$original `,` $diff `:` qualified(type($output)) attr-dict
}];

let hasFolder = 1;
}

def SetUnionOp : RTGOp<"set_union", [
Expand All @@ -349,6 +352,8 @@ def SetUnionOp : RTGOp<"set_union", [
let assemblyFormat = [{
$sets `:` qualified(type($result)) attr-dict
}];

let hasFolder = 1;
}

def SetSizeOp : RTGOp<"set_size", [Pure]> {
Expand All @@ -360,6 +365,8 @@ def SetSizeOp : RTGOp<"set_size", [Pure]> {
let assemblyFormat = [{
$set `:` qualified(type($set)) attr-dict
}];

let hasFolder = 1;
}

def SetCartesianProductOp : RTGOp<"set_cartesian_product", [
Expand Down Expand Up @@ -391,6 +398,8 @@ def SetCartesianProductOp : RTGOp<"set_cartesian_product", [
let results = (outs SetType:$result);

let assemblyFormat = "$inputs `:` qualified(type($inputs)) attr-dict";

let hasFolder = 1;
}

def SetConvertToBagOp : RTGOp<"set_convert_to_bag", [
Expand Down
98 changes: 98 additions & 0 deletions lib/Dialect/RTG/IR/RTGAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,104 @@
using namespace circt;
using namespace rtg;

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

namespace llvm {
template <typename T>
// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const DenseSet<T> &set) {
// TODO: improve collision resistance
unsigned hash = 0;
for (auto element : set)
hash ^= element;
return hash;
}
} // namespace llvm

//===----------------------------------------------------------------------===//
// SetAttr
//===----------------------------------------------------------------------===//

LogicalResult
SetAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
rtg::SetType type, const DenseSet<TypedAttr> *elements) {

// check that all elements have the right type
// iterating over the set is fine here because the iteration order is not
// visible to the outside (it would not be fine to print the earliest invalid
// element)
if (!llvm::all_of(*elements, [&](auto element) {
return element.getType() == type.getElementType();
})) {
return emitError() << "all elements must be of the set element type "
<< type.getElementType();
}

return success();
}

Attribute SetAttr::parse(AsmParser &odsParser, Type odsType) {
DenseSet<TypedAttr> elements;
Type elementType;
if (odsParser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
[&]() {
TypedAttr element;
if (odsParser.parseAttribute(element))
return failure();
elements.insert(element);
elementType = element.getType();
return success();
}))
return {};

auto setType = llvm::dyn_cast_or_null<SetType>(odsType);
if (odsType && !setType) {
odsParser.emitError(odsParser.getNameLoc())
<< "type must be a an '!rtg.set' type";
return {};
}

if (!setType && elements.empty()) {
odsParser.emitError(odsParser.getNameLoc())
<< "type must be explicitly provided: cannot infer set element type "
"from empty set";
return {};
}

if (!setType && !elements.empty())
setType = SetType::get(elementType);

return SetAttr::getChecked(
odsParser.getEncodedSourceLoc(odsParser.getNameLoc()),
odsParser.getContext(), setType, &elements);
}

void SetAttr::print(AsmPrinter &odsPrinter) const {
odsPrinter << "<";
// Sort elements lexicographically by their printed string representation
SmallVector<std::string> sortedElements;
for (auto element : *getElements()) {
std::string &elementStr = sortedElements.emplace_back();
llvm::raw_string_ostream elementOS(elementStr);
element.print(elementOS);
}
llvm::sort(sortedElements);
llvm::interleaveComma(sortedElements, odsPrinter);
odsPrinter << ">";
}

//===----------------------------------------------------------------------===//
// TupleAttr
//===----------------------------------------------------------------------===//

Type TupleAttr::getType() const {
SmallVector<Type> elementTypes(llvm::map_range(
getElements(), [](auto element) { return element.getType(); }));
return TupleType::get(getContext(), elementTypes);
}

//===----------------------------------------------------------------------===//
// ImmediateAttr
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/RTG/IR/RTGDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void RTGDialect::initialize() {
/// constant value. Otherwise, it should return null on failure.
Operation *RTGDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (auto attr = dyn_cast<ImmediateAttr>(value))
if (auto attr = dyn_cast<TypedAttr>(value))
if (type == attr.getType())
return ConstantOp::create(builder, loc, attr);

Expand Down
113 changes: 113 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,119 @@ OpFoldResult SliceImmediateOp::fold(FoldAdaptor adaptor) {
return {};
}

//===----------------------------------------------------------------------===//
// SetCreateOp
//===----------------------------------------------------------------------===//

OpFoldResult SetCreateOp::fold(FoldAdaptor adaptor) {
DenseSet<TypedAttr> elements;
for (auto attr : adaptor.getElements()) {
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr)
return {};

elements.insert(typedAttr);
}

return SetAttr::get(getType(), &elements);
}

//===----------------------------------------------------------------------===//
// SetSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult SetSizeOp::fold(FoldAdaptor adaptor) {
auto setAttr = dyn_cast_or_null<SetAttr>(adaptor.getSet());
if (!setAttr)
return {};

return IntegerAttr::get(IndexType::get(getContext()),
setAttr.getElements()->size());
}

//===----------------------------------------------------------------------===//
// SetUnionOp
//===----------------------------------------------------------------------===//

OpFoldResult SetUnionOp::fold(FoldAdaptor adaptor) {
// Fast track to make sure we're not computing the union of all sets but the
// last of the variadic operands is NULL.
if (llvm::any_of(adaptor.getSets(), [&](Attribute attr) { return !attr; }))
return {};

DenseSet<TypedAttr> res;
for (auto set : adaptor.getSets()) {
auto setAttr = dyn_cast<SetAttr>(set);
if (!set)
return {};

for (auto element : *setAttr.getElements())
res.insert(element);
}

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// SetDifferenceOp
//===----------------------------------------------------------------------===//

OpFoldResult SetDifferenceOp::fold(FoldAdaptor adaptor) {
auto original = dyn_cast_or_null<SetAttr>(adaptor.getOriginal());
auto diff = dyn_cast_or_null<SetAttr>(adaptor.getDiff());
if (!original || !diff)
return {};

DenseSet<TypedAttr> res(*original.getElements());
for (auto element : *diff.getElements())
res.erase(element);

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// SetCartesianProductOp
//===----------------------------------------------------------------------===//

OpFoldResult SetCartesianProductOp::fold(FoldAdaptor adaptor) {
// Fast track to make sure we're not computing the product of all sets but the
// last of the variadic operands is NULL.
if (llvm::any_of(adaptor.getInputs(), [&](Attribute attr) { return !attr; }))
return {};

DenseSet<TypedAttr> res;
SmallVector<SmallVector<TypedAttr>> tuples;
tuples.push_back({});

for (auto input : adaptor.getInputs()) {
auto setAttr = dyn_cast<SetAttr>(input);
if (!setAttr)
return {};

DenseSet<TypedAttr> set(*setAttr.getElements());
if (set.empty()) {
DenseSet<TypedAttr> empty;
return SetAttr::get(getType(), &empty);
}

for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
for (auto [k, el] : llvm::enumerate(set)) {
if (k == set.size() - 1) {
tuples[i].push_back(el);
continue;
}
tuples.push_back(tuples[i]);
tuples.back().push_back(el);
}
}
}

for (auto &tup : tuples)
res.insert(TupleAttr::get(getContext(), tup));

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@ rtg.test @constants() {

// CHECK-NEXT: rtg.isa.space [[V0]]
rtg.isa.space %1

// CHECK-NEXT: rtg.constant #rtg.set<> : !rtg.set<i32>
rtg.constant #rtg.set<> : !rtg.set<i32>

// Test that set elements are printed in lexicographic order
// CHECK-NEXT: rtg.constant #rtg.set<#rtgtest.a0 : !rtgtest.ireg, #rtgtest.a1 : !rtgtest.ireg, #rtgtest.a2 : !rtgtest.ireg> : !rtg.set<!rtgtest.ireg>
rtg.constant #rtg.set<#rtgtest.a1, #rtgtest.a0, #rtgtest.a2> : !rtg.set<!rtgtest.ireg>

// Test set type inference
// CHECK-NEXT: rtg.constant #rtg.set<0 : i32, 1 : i32, 2 : i32> : !rtg.set<i32>
rtg.constant #rtg.set<1 : i32, 0 : i32, 2 : i32>

// CHECK-NEXT: rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple<i32, index>
rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple<i32, index>

// Test set type inference
// CHECK-NEXT: rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple<i32, index>
rtg.constant #rtg.tuple<0 : i32, 1 : index>
}

// CHECK-LABEL: rtg.sequence @ranomizedSequenceType
Expand Down
Loading