From f75e9dc9a747788d068c4c4bf74d3e502a2a732c Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Wed, 5 Nov 2025 14:39:29 +0000 Subject: [PATCH 1/2] [RTG] Add SetAttr and TupleAttr --- include/circt/Dialect/RTG/IR/RTGAttributes.td | 46 +++++++++ lib/Dialect/RTG/IR/RTGAttributes.cpp | 98 +++++++++++++++++++ lib/Dialect/RTG/IR/RTGDialect.cpp | 2 +- test/Dialect/RTG/IR/basic.mlir | 18 ++++ test/Dialect/RTG/IR/errors.mlir | 21 ++++ 5 files changed, 184 insertions(+), 1 deletion(-) diff --git a/include/circt/Dialect/RTG/IR/RTGAttributes.td b/include/circt/Dialect/RTG/IR/RTGAttributes.td index ac4b7514e707..ae2468f2022a 100644 --- a/include/circt/Dialect/RTG/IR/RTGAttributes.td +++ b/include/circt/Dialect/RTG/IR/RTGAttributes.td @@ -51,6 +51,52 @@ def AnyContextAttr : RTGAttrDef<"AnyContext", [ let assemblyFormat = ""; } +// The accessor type is a pointer instead of a reference because references +// cannot be reassigned (in the attribute 'construct' function). +class DenseSetParameter + : AttrOrTypeParameter<"const ::llvm::DenseSet<" # setOf # "> *", desc> { + let allocator = [{ + $_dst = new ($_allocator.allocate>()) + DenseSet<}] # setOf # [{>($_self->begin(), $_self->end()); + }]; + let cppStorageType = "::llvm::DenseSet<" # setOf # ">"; +} + +def SetAttr : RTGAttrDef<"Set", [ + DeclareAttrInterfaceMethods, +]> { + let summary = "an unordered set of elements"; + + let parameters = (ins + AttributeSelfTypeParameter<"", "rtg::SetType">:$type, + DenseSetParameter<"TypedAttr", "elements">:$elements); + + let builders = [ + AttrBuilderWithInferredContext< + (ins "rtg::SetType":$type, + "const ::llvm::DenseSet<::mlir::TypedAttr> *":$elements), [{ + return $_get(type.getContext(), type, elements); + }]> + ]; + + let mnemonic = "set"; + let hasCustomAssemblyFormat = true; + + let genVerifyDecl = true; +} + +def TupleAttr : RTGAttrDef<"Tuple", [ + DeclareAttrInterfaceMethods, +]> { + let summary = "a tuple"; + + let parameters = + (ins ArrayRefParameter<"::mlir::TypedAttr", "elements">:$elements); + + let mnemonic = "tuple"; + let assemblyFormat = "`<` $elements `>`"; +} + //===----------------------------------------------------------------------===// // Attributes for ISA targets //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/RTG/IR/RTGAttributes.cpp b/lib/Dialect/RTG/IR/RTGAttributes.cpp index 43e545773a9b..d689d3b2d12d 100644 --- a/lib/Dialect/RTG/IR/RTGAttributes.cpp +++ b/lib/Dialect/RTG/IR/RTGAttributes.cpp @@ -15,6 +15,104 @@ using namespace circt; using namespace rtg; +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +namespace llvm { +template +// NOLINTNEXTLINE(readability-identifier-naming) +llvm::hash_code hash_value(const DenseSet &set) { + // TODO: improve collision resistance + unsigned hash = 0; + for (auto element : set) + hash ^= element; + return hash; +} +} // namespace llvm + +//===----------------------------------------------------------------------===// +// SetAttr +//===----------------------------------------------------------------------===// + +LogicalResult +SetAttr::verify(llvm::function_ref emitError, + rtg::SetType type, const DenseSet *elements) { + + // check that all elements have the right type + // iterating over the set is fine here because the iteration order is not + // visible to the outside (it would not be fine to print the earliest invalid + // element) + if (!llvm::all_of(*elements, [&](auto element) { + return element.getType() == type.getElementType(); + })) { + return emitError() << "all elements must be of the set element type " + << type.getElementType(); + } + + return success(); +} + +Attribute SetAttr::parse(AsmParser &odsParser, Type odsType) { + DenseSet elements; + Type elementType; + if (odsParser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater, + [&]() { + TypedAttr element; + if (odsParser.parseAttribute(element)) + return failure(); + elements.insert(element); + elementType = element.getType(); + return success(); + })) + return {}; + + auto setType = llvm::dyn_cast_or_null(odsType); + if (odsType && !setType) { + odsParser.emitError(odsParser.getNameLoc()) + << "type must be a an '!rtg.set' type"; + return {}; + } + + if (!setType && elements.empty()) { + odsParser.emitError(odsParser.getNameLoc()) + << "type must be explicitly provided: cannot infer set element type " + "from empty set"; + return {}; + } + + if (!setType && !elements.empty()) + setType = SetType::get(elementType); + + return SetAttr::getChecked( + odsParser.getEncodedSourceLoc(odsParser.getNameLoc()), + odsParser.getContext(), setType, &elements); +} + +void SetAttr::print(AsmPrinter &odsPrinter) const { + odsPrinter << "<"; + // Sort elements lexicographically by their printed string representation + SmallVector sortedElements; + for (auto element : *getElements()) { + std::string &elementStr = sortedElements.emplace_back(); + llvm::raw_string_ostream elementOS(elementStr); + element.print(elementOS); + } + llvm::sort(sortedElements); + llvm::interleaveComma(sortedElements, odsPrinter); + odsPrinter << ">"; +} + +//===----------------------------------------------------------------------===// +// TupleAttr +//===----------------------------------------------------------------------===// + +Type TupleAttr::getType() const { + SmallVector elementTypes(llvm::map_range( + getElements(), [](auto element) { return element.getType(); })); + return TupleType::get(getContext(), elementTypes); +} + //===----------------------------------------------------------------------===// // ImmediateAttr //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/RTG/IR/RTGDialect.cpp b/lib/Dialect/RTG/IR/RTGDialect.cpp index f5ef736a85bb..f40e82a8fa5b 100644 --- a/lib/Dialect/RTG/IR/RTGDialect.cpp +++ b/lib/Dialect/RTG/IR/RTGDialect.cpp @@ -44,7 +44,7 @@ void RTGDialect::initialize() { /// constant value. Otherwise, it should return null on failure. Operation *RTGDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (auto attr = dyn_cast(value)) + if (auto attr = dyn_cast(value)) if (type == attr.getType()) return ConstantOp::create(builder, loc, attr); diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index 22ad22a22637..17fd6307dad0 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -15,6 +15,24 @@ rtg.test @constants() { // CHECK-NEXT: rtg.isa.space [[V0]] rtg.isa.space %1 + + // CHECK-NEXT: rtg.constant #rtg.set<> : !rtg.set + rtg.constant #rtg.set<> : !rtg.set + + // Test that set elements are printed in lexicographic order + // CHECK-NEXT: rtg.constant #rtg.set<#rtgtest.a0 : !rtgtest.ireg, #rtgtest.a1 : !rtgtest.ireg, #rtgtest.a2 : !rtgtest.ireg> : !rtg.set + rtg.constant #rtg.set<#rtgtest.a1, #rtgtest.a0, #rtgtest.a2> : !rtg.set + + // Test set type inference + // CHECK-NEXT: rtg.constant #rtg.set<0 : i32, 1 : i32, 2 : i32> : !rtg.set + rtg.constant #rtg.set<1 : i32, 0 : i32, 2 : i32> + + // CHECK-NEXT: rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple + rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple + + // Test set type inference + // CHECK-NEXT: rtg.constant #rtg.tuple<0 : i32, 1 : index> : !rtg.tuple + rtg.constant #rtg.tuple<0 : i32, 1 : index> } // CHECK-LABEL: rtg.sequence @ranomizedSequenceType diff --git a/test/Dialect/RTG/IR/errors.mlir b/test/Dialect/RTG/IR/errors.mlir index 386f197349d6..02b3803c1599 100644 --- a/test/Dialect/RTG/IR/errors.mlir +++ b/test/Dialect/RTG/IR/errors.mlir @@ -329,3 +329,24 @@ rtg.test @concatImmediateNonImmediateOperand() { // expected-error @below {{all operands must be of immediate type}} %1 = rtg.isa.concat_immediate %0 : index } + +// ----- + +rtg.test @setAttrNotSetType() { + // expected-error @below {{type must be a an '!rtg.set' type}} + rtg.constant #rtg.set<> : !rtg.bag +} + +// ----- + +rtg.test @setAttrExplicitTypeRequired() { + // expected-error @below {{type must be explicitly provided: cannot infer set element type from empty set}} + rtg.constant #rtg.set<> +} + +// ----- + +rtg.test @setAttrExplicitTypeRequired() { + // expected-error @below {{all elements must be of the set element type 'i32'}} + rtg.constant #rtg.set<0 : index> : !rtg.set +} From 8ec9ab52e752d50e5df329ce10f8bd17b1138be9 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Wed, 5 Nov 2025 14:58:56 +0000 Subject: [PATCH 2/2] [RTG] Add folders for most set operations --- include/circt/Dialect/RTG/IR/RTGOps.td | 9 ++ lib/Dialect/RTG/IR/RTGOps.cpp | 113 ++++++++++++++++++++++ test/Dialect/RTG/IR/canonicalization.mlir | 60 ++++++++++++ 3 files changed, 182 insertions(+) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index ac525a8a9d97..9841c767b9ef 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -299,6 +299,7 @@ def SetCreateOp : RTGOp<"set_create", [Pure, SameTypeOperands]> { let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasFolder = 1; } def SetSelectRandomOp : RTGOp<"set_select_random", [ @@ -332,6 +333,8 @@ def SetDifferenceOp : RTGOp<"set_difference", [ let assemblyFormat = [{ $original `,` $diff `:` qualified(type($output)) attr-dict }]; + + let hasFolder = 1; } def SetUnionOp : RTGOp<"set_union", [ @@ -349,6 +352,8 @@ def SetUnionOp : RTGOp<"set_union", [ let assemblyFormat = [{ $sets `:` qualified(type($result)) attr-dict }]; + + let hasFolder = 1; } def SetSizeOp : RTGOp<"set_size", [Pure]> { @@ -360,6 +365,8 @@ def SetSizeOp : RTGOp<"set_size", [Pure]> { let assemblyFormat = [{ $set `:` qualified(type($set)) attr-dict }]; + + let hasFolder = 1; } def SetCartesianProductOp : RTGOp<"set_cartesian_product", [ @@ -391,6 +398,8 @@ def SetCartesianProductOp : RTGOp<"set_cartesian_product", [ let results = (outs SetType:$result); let assemblyFormat = "$inputs `:` qualified(type($inputs)) attr-dict"; + + let hasFolder = 1; } def SetConvertToBagOp : RTGOp<"set_convert_to_bag", [ diff --git a/lib/Dialect/RTG/IR/RTGOps.cpp b/lib/Dialect/RTG/IR/RTGOps.cpp index cd84ee136d59..9e4a810a67b2 100644 --- a/lib/Dialect/RTG/IR/RTGOps.cpp +++ b/lib/Dialect/RTG/IR/RTGOps.cpp @@ -900,6 +900,119 @@ OpFoldResult SliceImmediateOp::fold(FoldAdaptor adaptor) { return {}; } +//===----------------------------------------------------------------------===// +// SetCreateOp +//===----------------------------------------------------------------------===// + +OpFoldResult SetCreateOp::fold(FoldAdaptor adaptor) { + DenseSet elements; + for (auto attr : adaptor.getElements()) { + auto typedAttr = dyn_cast_or_null(attr); + if (!typedAttr) + return {}; + + elements.insert(typedAttr); + } + + return SetAttr::get(getType(), &elements); +} + +//===----------------------------------------------------------------------===// +// SetSizeOp +//===----------------------------------------------------------------------===// + +OpFoldResult SetSizeOp::fold(FoldAdaptor adaptor) { + auto setAttr = dyn_cast_or_null(adaptor.getSet()); + if (!setAttr) + return {}; + + return IntegerAttr::get(IndexType::get(getContext()), + setAttr.getElements()->size()); +} + +//===----------------------------------------------------------------------===// +// SetUnionOp +//===----------------------------------------------------------------------===// + +OpFoldResult SetUnionOp::fold(FoldAdaptor adaptor) { + // Fast track to make sure we're not computing the union of all sets but the + // last of the variadic operands is NULL. + if (llvm::any_of(adaptor.getSets(), [&](Attribute attr) { return !attr; })) + return {}; + + DenseSet res; + for (auto set : adaptor.getSets()) { + auto setAttr = dyn_cast(set); + if (!set) + return {}; + + for (auto element : *setAttr.getElements()) + res.insert(element); + } + + return SetAttr::get(getType(), &res); +} + +//===----------------------------------------------------------------------===// +// SetDifferenceOp +//===----------------------------------------------------------------------===// + +OpFoldResult SetDifferenceOp::fold(FoldAdaptor adaptor) { + auto original = dyn_cast_or_null(adaptor.getOriginal()); + auto diff = dyn_cast_or_null(adaptor.getDiff()); + if (!original || !diff) + return {}; + + DenseSet res(*original.getElements()); + for (auto element : *diff.getElements()) + res.erase(element); + + return SetAttr::get(getType(), &res); +} + +//===----------------------------------------------------------------------===// +// SetCartesianProductOp +//===----------------------------------------------------------------------===// + +OpFoldResult SetCartesianProductOp::fold(FoldAdaptor adaptor) { + // Fast track to make sure we're not computing the product of all sets but the + // last of the variadic operands is NULL. + if (llvm::any_of(adaptor.getInputs(), [&](Attribute attr) { return !attr; })) + return {}; + + DenseSet res; + SmallVector> tuples; + tuples.push_back({}); + + for (auto input : adaptor.getInputs()) { + auto setAttr = dyn_cast(input); + if (!setAttr) + return {}; + + DenseSet set(*setAttr.getElements()); + if (set.empty()) { + DenseSet empty; + return SetAttr::get(getType(), &empty); + } + + for (unsigned i = 0, e = tuples.size(); i < e; ++i) { + for (auto [k, el] : llvm::enumerate(set)) { + if (k == set.size() - 1) { + tuples[i].push_back(el); + continue; + } + tuples.push_back(tuples[i]); + tuples.back().push_back(el); + } + } + } + + for (auto &tup : tuples) + res.insert(TupleAttr::get(getContext(), tup)); + + return SetAttr::get(getType(), &res); +} + //===----------------------------------------------------------------------===// // TableGen generated logic. //===----------------------------------------------------------------------===// diff --git a/test/Dialect/RTG/IR/canonicalization.mlir b/test/Dialect/RTG/IR/canonicalization.mlir index 2bfcffd37aed..5090f94e3331 100644 --- a/test/Dialect/RTG/IR/canonicalization.mlir +++ b/test/Dialect/RTG/IR/canonicalization.mlir @@ -1,5 +1,10 @@ // RUN: circt-opt --canonicalize %s | FileCheck %s +func.func @dummy0(%arg0: !rtg.set) -> () {return} +func.func @dummy1(%arg0: index) -> () {return} +func.func @dummy2(%arg0: !rtg.set>) -> () {return} +func.func @dummy3(%arg0: !rtg.set>) -> () {return} + // CHECK-LABEL: @interleaveSequences rtg.test @interleaveSequences(seq0 = %seq0: !rtg.randomized_sequence) { // CHECK-NEXT: rtg.embed_sequence %seq0 @@ -19,3 +24,58 @@ rtg.target @immediates : !rtg.dict, imm1: !rtg.isa. // CHECK-NEXT: rtg.yield [[V0]], [[V1]] : rtg.yield %3, %4 : !rtg.isa.immediate<64>, !rtg.isa.immediate<2> } + +// CHECK-LABEL: @sets +rtg.test @sets() { + %idx0 = index.constant 0 + %idx1 = index.constant 1 + %set0 = rtg.constant #rtg.set<1 : index, 0 : index> : !rtg.set + %set1 = rtg.constant #rtg.set<1 : index, 2 : index> : !rtg.set + %set2 = rtg.constant #rtg.set<2 : index, 3 : index> : !rtg.set + %set3 = rtg.constant #rtg.set<> : !rtg.set + %set4 = rtg.constant #rtg.set<4 : i32, 5 : i32> : !rtg.set + %set5 = rtg.constant #rtg.set<6 : i64, 7 : i64> : !rtg.set + + // CHECK: [[SET:%.+]] = rtg.constant #rtg.set<0 : index, 1 : index> : !rtg.set + %0 = rtg.set_create %idx1, %idx0 : index + + // CHECK: [[SIZE:%.+]] = rtg.constant 2 : index + %size = rtg.set_size %0 : !rtg.set + + // CHECK: [[UNION:%.+]] = rtg.constant #rtg.set<0 : index, 1 : index, 2 : index, 3 : index> : !rtg.set + %union = rtg.set_union %set0, %set1, %set2 : !rtg.set + + // CHECK: [[DIFF:%.+]] = rtg.constant #rtg.set<0 : index> : !rtg.set + %diff = rtg.set_difference %set0, %set1 : !rtg.set + + // CHECK: [[PROD:%.+]] = rtg.constant #rtg.set< + // CHECK-SAME: #rtg.tuple<0 : index, 4 : i32, 6 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<0 : index, 4 : i32, 7 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<0 : index, 5 : i32, 6 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<0 : index, 5 : i32, 7 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<1 : index, 4 : i32, 6 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<1 : index, 4 : i32, 7 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<1 : index, 5 : i32, 6 : i64> : !rtg.tuple + // CHECK-SAME: #rtg.tuple<1 : index, 5 : i32, 7 : i64> : !rtg.tuple> + // CHECK-SAME: !rtg.set> + %prod0 = rtg.set_cartesian_product %set0, %set4, %set5 : !rtg.set, !rtg.set, !rtg.set + // CHECK: [[EMPTY:%.+]] = rtg.constant #rtg.set<> : !rtg.set> + %prod1 = rtg.set_cartesian_product %set0, %set4, %set3 : !rtg.set, !rtg.set, !rtg.set + // CHECK: [[SET2:%.+]] = rtg.constant #rtg.set<#rtg.tuple<0 : index> : !rtg.tuple, #rtg.tuple<1 : index> : !rtg.tuple> : !rtg.set> + %prod2 = rtg.set_cartesian_product %set0 : !rtg.set + + // CHECK: func.call @dummy0([[SET:%.+]]) + func.call @dummy0(%0) : (!rtg.set) -> () + // CHECK: func.call @dummy1([[SIZE:%.+]]) + func.call @dummy1(%size) : (index) -> () + // CHECK: func.call @dummy0([[UNION]]) + func.call @dummy0(%union) : (!rtg.set) -> () + // CHECK: func.call @dummy0([[DIFF]]) + func.call @dummy0(%diff) : (!rtg.set) -> () + // CHECK: func.call @dummy2([[PROD]]) + func.call @dummy2(%prod0) : (!rtg.set>) -> () + // CHECK: func.call @dummy2([[EMPTY]]) + func.call @dummy2(%prod1) : (!rtg.set>) -> () + // CHECK: func.call @dummy3([[SET2]]) : (!rtg.set>) -> () + func.call @dummy3(%prod2) : (!rtg.set>) -> () +}