From bd6feaaeca337be4030777d97ce18fd667164c90 Mon Sep 17 00:00:00 2001 From: Stefan Stipanovic Date: Tue, 4 Mar 2025 21:32:37 +0100 Subject: [PATCH 1/4] Value traits support Co-authored-by: Valentin Churavy --- example/ExampleDialect.td | 11 +++ include/llvm-dialects/Dialect/Dialect.td | 55 +++++++---- include/llvm-dialects/TableGen/Dialects.h | 4 +- include/llvm-dialects/TableGen/Traits.h | 7 +- lib/TableGen/Dialects.cpp | 6 +- lib/TableGen/GenDialect.cpp | 6 +- lib/TableGen/Operations.cpp | 41 +++++++++ lib/TableGen/Traits.cpp | 92 +++++++++++++++---- test/example/generated/ExampleDialect.cpp.inc | 78 ++++++++++------ test/example/generated/ExampleDialect.h.inc | 2 +- 10 files changed, 231 insertions(+), 71 deletions(-) diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 4c2e507..358ad7c 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -141,6 +141,13 @@ def Add32Op : ExampleOp<"add32", [Memory<[]>, NoUnwind, WillReturn]> { let results = (outs I32:$result); let arguments = (ins I32:$lhs, I32:$rhs, AttrI32:$extra); + let value_traits = [ + (NoCapture $lhs), + (NoCapture $rhs), + (NoUndef $lhs), + (NoUndef $result), + ]; + let summary = "add two numbers, and a little extra"; let description = [{ For those times when you want a little extra, this operation adds two @@ -152,6 +159,10 @@ def SizeOfOp : ExampleOp<"sizeof", [Memory<[]>, NoUnwind, WillReturn]> { let results = (outs I64:$result); let arguments = (ins type:$sizeof_type); + let value_traits = [ + (NoCapture $sizeof_type), + ]; + let summary = "size of a given type"; let description = [{ Returns the store size of the given type in bytes. diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index dcea24a..1a1147f 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -322,29 +322,42 @@ multiclass AttrEnum { /// Traits generally map to llvm::Attributes. // ============================================================================ -class Trait; +class TraitProperty; +def FnTrait : TraitProperty; +def ParamTrait : TraitProperty; +def RetTrait : TraitProperty; -class LlvmEnumAttributeTrait : Trait { +class Trait P> { + list Properties = P; +} + +class LlvmEnumAttributeTrait P> : Trait

{ string llvmEnum = llvmEnum_; } -def NoUnwind : LlvmEnumAttributeTrait<"NoUnwind">; -def WillReturn : LlvmEnumAttributeTrait<"WillReturn">; -def NoReturn : LlvmEnumAttributeTrait<"NoReturn">; -def NoRecurse : LlvmEnumAttributeTrait<"NoRecurse">; -def NoSync : LlvmEnumAttributeTrait<"NoSync">; -def NoFree : LlvmEnumAttributeTrait<"NoFree">; -def MustProgress : LlvmEnumAttributeTrait<"MustProgress">; -def NoCallback : LlvmEnumAttributeTrait<"NoCallback">; -def NoDuplicate : LlvmEnumAttributeTrait<"NoDuplicate">; -def NoBuiltin : LlvmEnumAttributeTrait<"NoBuiltin">; -def Builtin : LlvmEnumAttributeTrait<"Builtin">; -def InlineHint : LlvmEnumAttributeTrait<"InlineHint">; -def AlwaysInline : LlvmEnumAttributeTrait<"AlwaysInline">; -def Cold : LlvmEnumAttributeTrait<"Cold">; -def Hot : LlvmEnumAttributeTrait<"Hot">; -def Convergent : LlvmEnumAttributeTrait<"Convergent">; -def Speculatable : LlvmEnumAttributeTrait<"Speculatable">; +def NoUnwind : LlvmEnumAttributeTrait<"NoUnwind", [FnTrait]>; +def WillReturn : LlvmEnumAttributeTrait<"WillReturn", [FnTrait]>; +def NoReturn : LlvmEnumAttributeTrait<"NoReturn", [FnTrait]>; +def NoRecurse : LlvmEnumAttributeTrait<"NoRecurse", [FnTrait]>; +def NoSync : LlvmEnumAttributeTrait<"NoSync", [FnTrait]>; +def NoFree : LlvmEnumAttributeTrait<"NoFree", [FnTrait]>; +def MustProgress : LlvmEnumAttributeTrait<"MustProgress", [FnTrait]>; +def NoCallback : LlvmEnumAttributeTrait<"NoCallback", [FnTrait]>; +def NoDuplicate : LlvmEnumAttributeTrait<"NoDuplicate", [FnTrait]>; +def NoBuiltin : LlvmEnumAttributeTrait<"NoBuiltin", [FnTrait]>; +def Builtin : LlvmEnumAttributeTrait<"Builtin", [FnTrait]>; +def InlineHint : LlvmEnumAttributeTrait<"InlineHint", [FnTrait]>; +def AlwaysInline : LlvmEnumAttributeTrait<"AlwaysInline", [FnTrait]>; +def Cold : LlvmEnumAttributeTrait<"Cold", [FnTrait]>; +def Hot : LlvmEnumAttributeTrait<"Hot", [FnTrait]>; +def Convergent : LlvmEnumAttributeTrait<"Convergent", [FnTrait]>; +def Speculatable : LlvmEnumAttributeTrait<"Speculatable", [FnTrait]>; + +def NoCapture : LlvmEnumAttributeTrait<"NoCapture", [ParamTrait]>; +def ReadOnly : LlvmEnumAttributeTrait<"ReadOnly", [ParamTrait]>; + +def NoUndef : LlvmEnumAttributeTrait<"NoUndef", [ParamTrait, RetTrait]>; +def NonNull : LlvmEnumAttributeTrait<"NonNull", [ParamTrait, RetTrait]>; /// Represent the LLVM `memory(...)` attribute as the OR (or union) of memory /// effects. An empty effects list means the operation does not access memory @@ -358,7 +371,7 @@ def Speculatable : LlvmEnumAttributeTrait<"Speculatable">; /// Example: `Memory<[(ref), (mod ArgMem, InaccessibleMem)]>` means the /// operation may read from any kind of memory and write to argument and /// inaccessible memory. -class Memory effects_> : Trait { +class Memory effects_> : Trait<[FnTrait]> { list effects = effects_; } @@ -412,6 +425,8 @@ class Op traits_> { dag arguments = ?; dag results = ?; + list value_traits = []; + list verifier = []; string summary = ?; diff --git a/include/llvm-dialects/TableGen/Dialects.h b/include/llvm-dialects/TableGen/Dialects.h index 60317d8..6bb6ffd 100644 --- a/include/llvm-dialects/TableGen/Dialects.h +++ b/include/llvm-dialects/TableGen/Dialects.h @@ -19,6 +19,7 @@ #pragma once #include +#include #include "llvm-dialects/TableGen/Common.h" #include "llvm/ADT/ArrayRef.h" @@ -75,7 +76,7 @@ class GenDialectsContext { void init(RecordKeeperTy &records, const llvm::DenseSet &dialects); - Trait *getTrait(RecordTy *traitRec); + Trait *getTrait(RecordTy *traitRec, int idx = -1); Predicate *getPredicate(const llvm::Init *init, llvm::raw_ostream &errs); Attr *getAttr(RecordTy *record, llvm::raw_ostream &errs); OpClass *getOpClass(RecordTy *opClassRec); @@ -97,6 +98,7 @@ class GenDialectsContext { const llvm::Init *m_any = nullptr; bool m_attrsComplete = false; llvm::DenseMap> m_traits; + llvm::DenseMap>> m_valueTraits; llvm::DenseMap> m_predicates; llvm::DenseMap> m_attrs; llvm::DenseMap> m_opClasses; diff --git a/include/llvm-dialects/TableGen/Traits.h b/include/llvm-dialects/TableGen/Traits.h index 0a4eaad..171ff50 100644 --- a/include/llvm-dialects/TableGen/Traits.h +++ b/include/llvm-dialects/TableGen/Traits.h @@ -40,16 +40,19 @@ class Trait { enum class Kind : uint8_t { LlvmAttributeTrait_First, LlvmEnumAttributeTrait = LlvmAttributeTrait_First, + LlvmEnumFnAttributeTrait, + LlvmEnumRetAttributeTrait, + LlvmEnumParamAttributeTrait, LlvmMemoryAttributeTrait, LlvmAttributeTrait_Last = LlvmMemoryAttributeTrait, }; static std::unique_ptr fromRecord(GenDialectsContext *context, - RecordTy *record); + RecordTy *record, int idx = 0); virtual ~Trait() = default; - virtual void init(GenDialectsContext *context, RecordTy *record); + virtual void init(GenDialectsContext *context, RecordTy *record, int idx); Kind getKind() const { return m_kind; } RecordTy *getRecord() const { return m_record; } diff --git a/lib/TableGen/Dialects.cpp b/lib/TableGen/Dialects.cpp index 6c05eb9..c660361 100644 --- a/lib/TableGen/Dialects.cpp +++ b/lib/TableGen/Dialects.cpp @@ -79,14 +79,14 @@ void GenDialect::finalize(raw_ostream &errs) { GenDialectsContext::GenDialectsContext() = default; GenDialectsContext::~GenDialectsContext() = default; -Trait *GenDialectsContext::getTrait(RecordTy *traitRec) { +Trait *GenDialectsContext::getTrait(RecordTy *traitRec, int idx) { if (!traitRec->isSubClassOf("Trait")) report_fatal_error(Twine("Trying to use '") + traitRec->getName() + "' as a trait, but it is not a subclass of 'Trait'"); - auto &result = m_traits[traitRec]; + auto &result = idx < 0 ? m_traits[traitRec] : m_valueTraits[traitRec][idx]; if (!result) - result = Trait::fromRecord(this, traitRec); + result = Trait::fromRecord(this, traitRec, idx); return result.get(); } diff --git a/lib/TableGen/GenDialect.cpp b/lib/TableGen/GenDialect.cpp index 63b00c3..ba655a9 100644 --- a/lib/TableGen/GenDialect.cpp +++ b/lib/TableGen/GenDialect.cpp @@ -353,9 +353,11 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { if (!dialect->attribute_lists_empty()) { FmtContextScope scope{fmt}; fmt.addSubst("attrBuilder", "attrBuilder"); + fmt.addSubst("argAttrList", "argAttrList"); for (const auto &enumeratedTraits : enumerate(dialect->attribute_lists())) { out << tgfmt("{\n ::llvm::AttrBuilder $attrBuilder{context};\n", &fmt); + out << tgfmt(" ::llvm::AttributeList $argAttrList;\n", &fmt); for (const Trait *trait : enumeratedTraits.value()) { if (auto *llvmAttribute = dyn_cast(trait)) { @@ -365,8 +367,8 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { } } - out << tgfmt("m_attributeLists[$0] = ::llvm::AttributeList::get(context, " - "::llvm::AttributeList::FunctionIndex, $attrBuilder);\n}\n", + out << tgfmt("m_attributeLists[$0] = " + "$argAttrList.addFnAttributes(context, $attrBuilder);\n}\n", &fmt, enumeratedTraits.index()); } } diff --git a/lib/TableGen/Operations.cpp b/lib/TableGen/Operations.cpp index 54e1643..b2feaae 100644 --- a/lib/TableGen/Operations.cpp +++ b/lib/TableGen/Operations.cpp @@ -368,6 +368,47 @@ bool Operation::parse(raw_ostream &errs, GenDialectsContext *context, for (RecordTy *traitRec : record->getValueAsListOfDefs("traits")) op->traits.push_back(context->getTrait(traitRec)); + + const RecordVal *insVal = record->getValue("arguments"); + std::unordered_map nameToIndexMap; + if (const DagInit *DI = dyn_cast(insVal->getValue())){ + for (unsigned i = 0; i < DI->getNumArgs(); ++i) { + StringRef name = DI->getArgNameStr(i); + nameToIndexMap[name.str()] = i + 1; + } + } + + const RecordVal *outsVal = record->getValue("results"); + if (const DagInit *DI = dyn_cast(outsVal->getValue())) { + if (DI->getNumArgs() > 0) { + StringRef name = DI->getArgNameStr(0); + nameToIndexMap[name.str()] = 0; + } + } + + const ListInit *List = record->getValueAsListInit("value_traits"); + for (const Init *I : List->getValues()) { + if (const DagInit *DI = dyn_cast(I)) { + if (DI->getNumArgs() != 1) { + errs << "value_traits " << *DI << " is missing argument name"; + return false; + } + + StringRef name = DI->getArgNameStr(0); + + if (const DefInit *Op = dyn_cast(DI->getOperator())) { + op->traits.push_back( + context->getTrait(Op->getDef(), nameToIndexMap[name.str()])); + } else { + errs << "value_traits " << *DI << " is not of form (Trait $arg)"; + return false; + } + } else { + errs << "value_traits was not a list of DAG's"; + return false; + } + } + EvaluationPlanner evaluation(op->m_system); for (const auto &arg : op->getFullArguments()) { diff --git a/lib/TableGen/Traits.cpp b/lib/TableGen/Traits.cpp index 7f5a18a..dc384c8 100644 --- a/lib/TableGen/Traits.cpp +++ b/lib/TableGen/Traits.cpp @@ -36,11 +36,11 @@ static cl::opt NoMemoryEffects( class LlvmEnumAttributeTrait : public LlvmAttributeTrait { public: - LlvmEnumAttributeTrait() : LlvmAttributeTrait(Kind::LlvmEnumAttributeTrait) {} + LlvmEnumAttributeTrait(Kind kind) : LlvmAttributeTrait(kind) {} - void init(GenDialectsContext *context, RecordTy *record) override; + void init(GenDialectsContext *context, RecordTy *record, int idx) override; - void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; + virtual void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override = 0; llvm::StringRef getLlvmEnum() const { return m_llvmEnum; } @@ -52,12 +52,46 @@ class LlvmEnumAttributeTrait : public LlvmAttributeTrait { std::string m_llvmEnum; }; +class LlvmEnumFnAttributeTrait : public LlvmEnumAttributeTrait { +public: + LlvmEnumFnAttributeTrait() + : LlvmEnumAttributeTrait(Kind::LlvmEnumFnAttributeTrait) {} + + void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; +}; +class LlvmEnumRetAttributeTrait : public LlvmEnumAttributeTrait { +public: + LlvmEnumRetAttributeTrait() + : LlvmEnumAttributeTrait(Kind::LlvmEnumRetAttributeTrait) {} + + void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; +}; + +class LlvmEnumParamAttributeTrait : public LlvmEnumAttributeTrait { +public: + LlvmEnumParamAttributeTrait() + : LlvmEnumAttributeTrait(Kind::LlvmEnumParamAttributeTrait) {} + + void init(GenDialectsContext *context, RecordTy *record, int idx) override; + + void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; + + int getIdx() const { return m_idx; } + + static bool classof(const Trait *t) { + return t->getKind() == Kind::LlvmEnumParamAttributeTrait; + } + +private: + int m_idx; +}; + class LlvmMemoryAttributeTrait : public LlvmAttributeTrait { public: LlvmMemoryAttributeTrait() : LlvmAttributeTrait(Kind::LlvmMemoryAttributeTrait) {} - void init(GenDialectsContext *context, RecordTy *record) override; + void init(GenDialectsContext *context, RecordTy *record, int idx) override; void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; @@ -83,40 +117,66 @@ bool llvm_dialects::noMemoryEffects() { } std::unique_ptr Trait::fromRecord(GenDialectsContext *context, - RecordTy *traitRec) { + RecordTy *traitRec, int idx) { std::unique_ptr result; if (traitRec->isSubClassOf("LlvmEnumAttributeTrait")) { - result = std::make_unique(); + if (idx < 0) { + result = std::make_unique(); + } else if (idx == 0) { + result = std::make_unique(); + } else { + result = std::make_unique(); + } } else if (traitRec->isSubClassOf("Memory")) { result = std::make_unique(); } else { report_fatal_error(Twine("unsupported trait: ") + traitRec->getName()); } - result->init(context, traitRec); + result->init(context, traitRec, idx); return result; } -void Trait::init(GenDialectsContext *context, RecordTy *record) { +void Trait::init(GenDialectsContext *context, RecordTy *record, int idx) { m_record = record; } StringRef Trait::getName() const { return m_record->getName(); } void LlvmEnumAttributeTrait::init(GenDialectsContext *context, - RecordTy *record) { - LlvmAttributeTrait::init(context, record); + RecordTy *record, int idx) { + LlvmAttributeTrait::init(context, record, idx); m_llvmEnum = record->getValueAsString("llvmEnum"); } -void LlvmEnumAttributeTrait::addAttribute(raw_ostream &out, - FmtContext &fmt) const { - out << tgfmt("$attrBuilder.addAttribute(::llvm::Attribute::$0);\n", &fmt, - getLlvmEnum()); +void LlvmEnumParamAttributeTrait::init(GenDialectsContext *context, + RecordTy *record, int idx) { + LlvmEnumAttributeTrait::init(context, record, idx); + m_idx = idx - 1; +} + +void LlvmEnumFnAttributeTrait::addAttribute(raw_ostream &out, + FmtContext &fmt) const { + out << tgfmt("$attrBuilder.addAttribute(::llvm::Attribute::$0);\n", &fmt, + getLlvmEnum()); +} + +void LlvmEnumRetAttributeTrait::addAttribute(raw_ostream &out, + FmtContext &fmt) const { + out << tgfmt("$argAttrList = $argAttrList.addRetAttribute(context, " + "::llvm::Attribute::$0);\n", + &fmt, getLlvmEnum()); +} + +void LlvmEnumParamAttributeTrait::addAttribute(raw_ostream &out, + FmtContext &fmt) const { + out << tgfmt("$argAttrList = $argAttrList.addParamAttribute(context, $0, " + "::llvm::Attribute::$1);\n", + &fmt, getIdx(), getLlvmEnum()); } void LlvmMemoryAttributeTrait::init(GenDialectsContext *context, - RecordTy *record) { - LlvmAttributeTrait::init(context, record); + RecordTy *record, int idx) { + LlvmAttributeTrait::init(context, record, idx); auto *effects = record->getValueAsListInit("effects"); for (auto *effectInit : *effects) { diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 06f313a..dd47655 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -185,35 +185,61 @@ namespace xd::cpp { ExampleDialect::ExampleDialect(::llvm::LLVMContext& context) : DialectImpl(context) { { ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); +argAttrList = argAttrList.addParamAttribute(context, 1, ::llvm::Attribute::NoCapture); +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoUndef); +argAttrList = argAttrList.addRetAttribute(context, ::llvm::Attribute::NoUndef); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); -m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); +m_attributeLists[0] = argAttrList.addFnAttributes(context, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); +attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); +attrBuilder.addAttribute(::llvm::Attribute::WillReturn); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); +m_attributeLists[1] = argAttrList.addFnAttributes(context, attrBuilder); +} +{ + ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; +attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); +attrBuilder.addAttribute(::llvm::Attribute::WillReturn); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); +m_attributeLists[2] = argAttrList.addFnAttributes(context, attrBuilder); +} +{ + ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); -m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); +m_attributeLists[3] = argAttrList.addFnAttributes(context, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod)); -m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); +m_attributeLists[4] = argAttrList.addFnAttributes(context, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::ModRef)); -m_attributeLists[3] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); +m_attributeLists[5] = argAttrList.addFnAttributes(context, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; + ::llvm::AttributeList argAttrList; attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -m_attributeLists[4] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); +m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder); } } @@ -484,7 +510,7 @@ uint32_t const extra = getExtra(); const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {lhs->getType()}); @@ -579,7 +605,7 @@ rhs const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {::llvm::cast(vector->getType())->getElementType()}); @@ -683,7 +709,7 @@ index const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -853,7 +879,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), { }, false); @@ -915,7 +941,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1013,7 +1039,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1111,7 +1137,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { ::llvm::IntegerType::get(context, 1), }, false); @@ -1180,7 +1206,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {vector->getType()}); @@ -1302,7 +1328,7 @@ index const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1386,7 +1412,7 @@ instName_0 const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1460,7 +1486,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1640,7 +1666,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(3); + = ExampleDialect::get(context).getAttributeList(5); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {dataType}); @@ -1697,7 +1723,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(3); + = ExampleDialect::get(context).getAttributeList(5); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {dataType}); @@ -1754,7 +1780,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(4); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1817,7 +1843,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1891,7 +1917,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1983,7 +2009,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2075,7 +2101,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2167,7 +2193,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { ::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0), }, false); @@ -2236,7 +2262,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(4); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -2299,7 +2325,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(4); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index ecb210a..5ce84c5 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -47,7 +47,7 @@ namespace xd::cpp { } private: - ::std::array<::llvm::AttributeList, 5> m_attributeLists; + ::std::array<::llvm::AttributeList, 7> m_attributeLists; }; class XdHandleType : public ::llvm::TargetExtType { From 3be8be77fd13afcb3a3f9a55e28c5abbec3f6096 Mon Sep 17 00:00:00 2001 From: Stefan Stipanovic Date: Tue, 11 Mar 2025 00:15:34 +0100 Subject: [PATCH 2/4] address review comments --- example/ExampleDialect.td | 28 +++- include/llvm-dialects/Dialect/Dialect.td | 2 + include/llvm-dialects/TableGen/Operations.h | 6 +- lib/TableGen/Operations.cpp | 88 ++++++----- lib/TableGen/Traits.cpp | 73 ++------- test/example/generated/ExampleDialect.cpp.inc | 148 ++++++++++++++++-- test/example/generated/ExampleDialect.h.inc | 34 ++++ 7 files changed, 256 insertions(+), 123 deletions(-) diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 358ad7c..358767d 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -141,13 +141,6 @@ def Add32Op : ExampleOp<"add32", [Memory<[]>, NoUnwind, WillReturn]> { let results = (outs I32:$result); let arguments = (ins I32:$lhs, I32:$rhs, AttrI32:$extra); - let value_traits = [ - (NoCapture $lhs), - (NoCapture $rhs), - (NoUndef $lhs), - (NoUndef $result), - ]; - let summary = "add two numbers, and a little extra"; let description = [{ For those times when you want a little extra, this operation adds two @@ -246,6 +239,10 @@ def IExtOp : ExampleOp<"iext", [Memory<[]>, NoUnwind, WillReturn]> { def StreamReduceOp : OpClass { let arguments = (ins Ptr:$ptr, I64:$count, value:$initial); + let value_traits = [ + (NoCapture $ptr) + ]; + let summary = "family of operations that reduce some array in memory"; let description = [{ Illustrate the use of the OpClass feature. @@ -347,3 +344,20 @@ def NoDescriptionOp : Op { let summary = "Some summary"; } + +def BufferCompareOp : Op { + let results = (outs I32:$ret); + let arguments = (ins Ptr:$lhs, Ptr:$rhs); + + let value_traits = [ + (NoCapture $lhs), + (NoCapture $rhs), + (NoUndef $lhs), + (NoUndef $ret), + ]; + + let summary = "demonstrate how multiple parameter attributes are added"; + let description = [{ + Both arguments get a parameter attribute, as well as return value + }]; +} diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index 1a1147f..45ba0e4 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -407,6 +407,8 @@ class OpClass : OpClassBase { dag arguments = ?; + list value_traits = []; + string summary = ?; string description = ?; } diff --git a/include/llvm-dialects/TableGen/Operations.h b/include/llvm-dialects/TableGen/Operations.h index 8bddc4d..ae4a545 100644 --- a/include/llvm-dialects/TableGen/Operations.h +++ b/include/llvm-dialects/TableGen/Operations.h @@ -67,6 +67,11 @@ class OperationBase { void emitArgumentAccessorDefinitions(llvm::raw_ostream &out, FmtContext &fmt) const; + void parseValueTraits(llvm::raw_ostream &errs, RecordTy *record, + GenDialectsContext &context); + + std::vector traits; + protected: bool init(llvm::raw_ostream &errs, GenDialectsContext &context, RecordTy *record); @@ -105,7 +110,6 @@ class Operation : public OperationBase { std::string mnemonic; std::string summary; std::string description; - std::vector traits; std::vector results; diff --git a/lib/TableGen/Operations.cpp b/lib/TableGen/Operations.cpp index b2feaae..0350219 100644 --- a/lib/TableGen/Operations.cpp +++ b/lib/TableGen/Operations.cpp @@ -22,6 +22,7 @@ #include "llvm-dialects/TableGen/Dialects.h" #include "llvm-dialects/TableGen/Format.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/TableGen/Record.h" using namespace llvm; @@ -90,6 +91,11 @@ bool OperationBase::init(raw_ostream &errs, GenDialectsContext &context, m_arguments = std::move(*arguments); + if (m_superclass && m_superclass->traits.size() > 0) + traits = m_superclass->traits; + + parseValueTraits(errs, record, context); + // Don't allow any other arguments if the superclass already uses // variadic arguments, as the arguments will be appended to the arguments of // the superclass. @@ -297,6 +303,47 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out, } } +void OperationBase::parseValueTraits(raw_ostream &errs, RecordTy *record, + GenDialectsContext &context) { + const DagInit *insDag = record->getValueAsDag("arguments"); + std::unordered_map nameToIndexMap; + for (unsigned i = 0; i < insDag->getNumArgs(); ++i) { + StringRef name = insDag->getArgNameStr(i); + nameToIndexMap[name.str()] = i + 1; + } + + const RecordVal *outsVal = record->getValue("results"); + if (outsVal) { + const DagInit *DI = cast(outsVal->getValue()); + if (DI->getNumArgs() > 0) { + StringRef name = DI->getArgNameStr(0); + nameToIndexMap[name.str()] = 0; + } + } + + const ListInit *List = record->getValueAsListInit("value_traits"); + for (const Init *I : List->getValues()) { + if (const DagInit *DI = dyn_cast(I)) { + if (DI->getNumArgs() != 1) { + errs << "value_traits " << *DI << " is missing argument name"; + return; + } + + StringRef name = DI->getArgNameStr(0); + + if (const DefInit *Op = dyn_cast(DI->getOperator())) { + traits.push_back( + context.getTrait(Op->getDef(), nameToIndexMap[name.str()])); + } else { + errs << "value_traits " << *DI << " is not of form (Trait $arg)"; + return; + } + } else { + report_fatal_error("value_traits was not a list of DAG's"); + } + } +} + std::unique_ptr OpClass::parse(raw_ostream &errs, GenDialectsContext &context, RecordTy *record) { @@ -368,47 +415,6 @@ bool Operation::parse(raw_ostream &errs, GenDialectsContext *context, for (RecordTy *traitRec : record->getValueAsListOfDefs("traits")) op->traits.push_back(context->getTrait(traitRec)); - - const RecordVal *insVal = record->getValue("arguments"); - std::unordered_map nameToIndexMap; - if (const DagInit *DI = dyn_cast(insVal->getValue())){ - for (unsigned i = 0; i < DI->getNumArgs(); ++i) { - StringRef name = DI->getArgNameStr(i); - nameToIndexMap[name.str()] = i + 1; - } - } - - const RecordVal *outsVal = record->getValue("results"); - if (const DagInit *DI = dyn_cast(outsVal->getValue())) { - if (DI->getNumArgs() > 0) { - StringRef name = DI->getArgNameStr(0); - nameToIndexMap[name.str()] = 0; - } - } - - const ListInit *List = record->getValueAsListInit("value_traits"); - for (const Init *I : List->getValues()) { - if (const DagInit *DI = dyn_cast(I)) { - if (DI->getNumArgs() != 1) { - errs << "value_traits " << *DI << " is missing argument name"; - return false; - } - - StringRef name = DI->getArgNameStr(0); - - if (const DefInit *Op = dyn_cast(DI->getOperator())) { - op->traits.push_back( - context->getTrait(Op->getDef(), nameToIndexMap[name.str()])); - } else { - errs << "value_traits " << *DI << " is not of form (Trait $arg)"; - return false; - } - } else { - errs << "value_traits was not a list of DAG's"; - return false; - } - } - EvaluationPlanner evaluation(op->m_system); for (const auto &arg : op->getFullArguments()) { diff --git a/lib/TableGen/Traits.cpp b/lib/TableGen/Traits.cpp index dc384c8..186b6ea 100644 --- a/lib/TableGen/Traits.cpp +++ b/lib/TableGen/Traits.cpp @@ -36,53 +36,23 @@ static cl::opt NoMemoryEffects( class LlvmEnumAttributeTrait : public LlvmAttributeTrait { public: - LlvmEnumAttributeTrait(Kind kind) : LlvmAttributeTrait(kind) {} + LlvmEnumAttributeTrait() : LlvmAttributeTrait(Kind::LlvmEnumAttributeTrait) {} void init(GenDialectsContext *context, RecordTy *record, int idx) override; - virtual void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override = 0; + virtual void addAttribute(llvm::raw_ostream &out, + FmtContext &fmt) const override; llvm::StringRef getLlvmEnum() const { return m_llvmEnum; } + int getIdx() const { return m_idx - 1; } + static bool classof(const Trait *t) { return t->getKind() == Kind::LlvmEnumAttributeTrait; } private: std::string m_llvmEnum; -}; - -class LlvmEnumFnAttributeTrait : public LlvmEnumAttributeTrait { -public: - LlvmEnumFnAttributeTrait() - : LlvmEnumAttributeTrait(Kind::LlvmEnumFnAttributeTrait) {} - - void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; -}; -class LlvmEnumRetAttributeTrait : public LlvmEnumAttributeTrait { -public: - LlvmEnumRetAttributeTrait() - : LlvmEnumAttributeTrait(Kind::LlvmEnumRetAttributeTrait) {} - - void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; -}; - -class LlvmEnumParamAttributeTrait : public LlvmEnumAttributeTrait { -public: - LlvmEnumParamAttributeTrait() - : LlvmEnumAttributeTrait(Kind::LlvmEnumParamAttributeTrait) {} - - void init(GenDialectsContext *context, RecordTy *record, int idx) override; - - void addAttribute(llvm::raw_ostream &out, FmtContext &fmt) const override; - - int getIdx() const { return m_idx; } - - static bool classof(const Trait *t) { - return t->getKind() == Kind::LlvmEnumParamAttributeTrait; - } - -private: int m_idx; }; @@ -120,13 +90,7 @@ std::unique_ptr Trait::fromRecord(GenDialectsContext *context, RecordTy *traitRec, int idx) { std::unique_ptr result; if (traitRec->isSubClassOf("LlvmEnumAttributeTrait")) { - if (idx < 0) { - result = std::make_unique(); - } else if (idx == 0) { - result = std::make_unique(); - } else { - result = std::make_unique(); - } + result = std::make_unique(); } else if (traitRec->isSubClassOf("Memory")) { result = std::make_unique(); } else { @@ -146,29 +110,22 @@ void LlvmEnumAttributeTrait::init(GenDialectsContext *context, RecordTy *record, int idx) { LlvmAttributeTrait::init(context, record, idx); m_llvmEnum = record->getValueAsString("llvmEnum"); + m_idx = idx; } -void LlvmEnumParamAttributeTrait::init(GenDialectsContext *context, - RecordTy *record, int idx) { - LlvmEnumAttributeTrait::init(context, record, idx); - m_idx = idx - 1; -} - -void LlvmEnumFnAttributeTrait::addAttribute(raw_ostream &out, - FmtContext &fmt) const { +void LlvmEnumAttributeTrait::addAttribute(raw_ostream &out, + FmtContext &fmt) const { + // Function attribute. + if (m_idx < 0) out << tgfmt("$attrBuilder.addAttribute(::llvm::Attribute::$0);\n", &fmt, getLlvmEnum()); -} - -void LlvmEnumRetAttributeTrait::addAttribute(raw_ostream &out, - FmtContext &fmt) const { + // Return attribute. + else if (m_idx == 0) out << tgfmt("$argAttrList = $argAttrList.addRetAttribute(context, " "::llvm::Attribute::$0);\n", &fmt, getLlvmEnum()); -} - -void LlvmEnumParamAttributeTrait::addAttribute(raw_ostream &out, - FmtContext &fmt) const { + // Param attribute. + else out << tgfmt("$argAttrList = $argAttrList.addParamAttribute(context, $0, " "::llvm::Attribute::$1);\n", &fmt, getIdx(), getLlvmEnum()); diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index dd47655..ea63945 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -57,6 +57,11 @@ namespace xd::cpp { state.setError(); }); + builder.add([](::llvm_dialects::VerifierState &state, BufferCompareOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + builder.add([](::llvm_dialects::VerifierState &state, CombineOp &op) { if (!op.verifier(state.out())) state.setError(); @@ -192,7 +197,6 @@ argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoUnd argAttrList = argAttrList.addRetAttribute(context, ::llvm::Attribute::NoUndef); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); m_attributeLists[0] = argAttrList.addFnAttributes(context, attrBuilder); } { @@ -207,9 +211,10 @@ m_attributeLists[1] = argAttrList.addFnAttributes(context, attrBuilder); { ::llvm::AttrBuilder attrBuilder{context}; ::llvm::AttributeList argAttrList; +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); m_attributeLists[2] = argAttrList.addFnAttributes(context, attrBuilder); } { @@ -217,7 +222,7 @@ m_attributeLists[2] = argAttrList.addFnAttributes(context, attrBuilder); ::llvm::AttributeList argAttrList; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); m_attributeLists[3] = argAttrList.addFnAttributes(context, attrBuilder); } { @@ -388,7 +393,7 @@ return true; const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(3); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), { lhs->getType(), rhs->getType(), @@ -501,6 +506,109 @@ uint32_t const extra = getExtra(); + const ::llvm::StringLiteral BufferCompareOp::s_name{"xd.ir.buffer.compare.op"}; + + BufferCompareOp* BufferCompareOp::create(llvm_dialects::Builder& b, ::llvm::Value * lhs, ::llvm::Value * rhs, const llvm::Twine &instName) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(0); + auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), { +lhs->getType(), +rhs->getType(), +}, false); + + auto fn = module.getOrInsertFunction(s_name, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << s_name << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + ::llvm::SmallVector<::llvm::Value*, 2> args = { +lhs, +rhs + }; + + return ::llvm::cast(b.CreateCall(fn, args, instName)); + } + + + bool BufferCompareOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 2) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 2\n"; + return false; + } + ::llvm::Type * const lhsType = getLhs()->getType(); +(void)lhsType; +::llvm::Type * const rhsType = getRhs()->getType(); +(void)rhsType; +::llvm::Type * const retType = getRet()->getType(); +(void)retType; + + if (::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0) != lhsType) { + errs << " unexpected value of $lhs:\n"; + errs << " expected: " << printable(::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0)) << '\n'; + errs << " actual: " << printable(lhsType) << '\n'; + + return false; + } + + if (::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0) != rhsType) { + errs << " unexpected value of $rhs:\n"; + errs << " expected: " << printable(::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0)) << '\n'; + errs << " actual: " << printable(rhsType) << '\n'; + + return false; + } + + if (::llvm::IntegerType::get(context, 32) != retType) { + errs << " unexpected value of $ret:\n"; + errs << " expected: " << printable(::llvm::IntegerType::get(context, 32)) << '\n'; + errs << " actual: " << printable(retType) << '\n'; + + return false; + } + return true; +} + + + ::llvm::Value * BufferCompareOp::getLhs() const { + return getArgOperand(ArgumentIndex::Lhs); + } + + void BufferCompareOp::setLhs(::llvm::Value * lhs) { + setArgOperand(ArgumentIndex::Lhs, lhs); + } + ::llvm::Value * BufferCompareOp::getRhs() const { + return getArgOperand(ArgumentIndex::Rhs); + } + + void BufferCompareOp::setRhs(::llvm::Value * rhs) { + setArgOperand(ArgumentIndex::Rhs, rhs); + } +::llvm::Value *BufferCompareOp::getRet() {return this;} + + + const ::llvm::StringLiteral CombineOp::s_name{"xd.ir.combine"}; CombineOp* CombineOp::create(llvm_dialects::Builder& b, ::llvm::Value * lhs, ::llvm::Value * rhs, const llvm::Twine &instName) { @@ -510,7 +618,7 @@ uint32_t const extra = getExtra(); const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {lhs->getType()}); @@ -605,7 +713,7 @@ rhs const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {::llvm::cast(vector->getType())->getElementType()}); @@ -709,7 +817,7 @@ index const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -879,7 +987,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), { }, false); @@ -941,7 +1049,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1039,7 +1147,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1206,7 +1314,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(3); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {vector->getType()}); @@ -1564,7 +1672,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { }, false); @@ -1615,7 +1723,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(4); + = ExampleDialect::get(context).getAttributeList(6); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { }, false); @@ -1917,7 +2025,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(3); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2009,7 +2117,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(3); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2101,7 +2209,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(3); + = ExampleDialect::get(context).getAttributeList(2); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2407,6 +2515,14 @@ data } + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{false, "xd.ir.buffer.compare.op"}; + return desc; + } + + template <> const ::llvm_dialects::OpDescription & ::llvm_dialects::OpDescription::get() { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 5ce84c5..91f9b47 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -151,6 +151,40 @@ Extra = 2, ::llvm::Value * getResult(); + }; + + /// BufferCompareOp +/// demonstrate how multiple parameter attributes are added +/// +/// Both arguments get a parameter attribute, as well as return value +/// + + class BufferCompareOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.ir.buffer.compare.op"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isSimpleOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static BufferCompareOp* create(::llvm_dialects::Builder& b, ::llvm::Value * lhs, ::llvm::Value * rhs, const llvm::Twine &instName = ""); + +bool verifier(::llvm::raw_ostream &errs); + +::llvm::Value * getLhs() const; + void setLhs(::llvm::Value * lhs); + ::llvm::Value * getRhs() const; + void setRhs(::llvm::Value * rhs); + struct ArgumentIndex { enum Enum : uint32_t { +Lhs = 0, +Rhs = 1, +};}; +::llvm::Value * getRet(); + + }; /// CombineOp From 5f97bae1eb874c819db9b45719e4dc609edd8fbb Mon Sep 17 00:00:00 2001 From: Stefan Stipanovic Date: Tue, 11 Mar 2025 13:09:19 +0100 Subject: [PATCH 3/4] change how NoCapture is emmited: nocapture -> captures(none) --- lib/TableGen/Traits.cpp | 26 ++++++++++++------- test/example/generated/ExampleDialect.cpp.inc | 8 +++--- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/lib/TableGen/Traits.cpp b/lib/TableGen/Traits.cpp index 186b6ea..797713f 100644 --- a/lib/TableGen/Traits.cpp +++ b/lib/TableGen/Traits.cpp @@ -115,20 +115,28 @@ void LlvmEnumAttributeTrait::init(GenDialectsContext *context, void LlvmEnumAttributeTrait::addAttribute(raw_ostream &out, FmtContext &fmt) const { - // Function attribute. - if (m_idx < 0) + if (m_idx < 0) { + // Function attribute. out << tgfmt("$attrBuilder.addAttribute(::llvm::Attribute::$0);\n", &fmt, getLlvmEnum()); - // Return attribute. - else if (m_idx == 0) + } else if (m_idx == 0) { + // Return attribute. out << tgfmt("$argAttrList = $argAttrList.addRetAttribute(context, " "::llvm::Attribute::$0);\n", &fmt, getLlvmEnum()); - // Param attribute. - else - out << tgfmt("$argAttrList = $argAttrList.addParamAttribute(context, $0, " - "::llvm::Attribute::$1);\n", - &fmt, getIdx(), getLlvmEnum()); + } else { + // Param attribute. + if (getLlvmEnum() == "NoCapture") { + out << tgfmt("$argAttrList = $argAttrList.addParamAttribute(context, $0, " + "::llvm::Attribute::getWithCaptureInfo(context, " + "llvm::CaptureInfo::none()));\n", + &fmt, getIdx()); + } else { + out << tgfmt("$argAttrList = $argAttrList.addParamAttribute(context, $0, " + "::llvm::Attribute::$1);\n", + &fmt, getIdx(), getLlvmEnum()); + } + } } void LlvmMemoryAttributeTrait::init(GenDialectsContext *context, diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index ea63945..abd6fea 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -191,8 +191,8 @@ namespace xd::cpp { { ::llvm::AttrBuilder attrBuilder{context}; ::llvm::AttributeList argAttrList; -argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); -argAttrList = argAttrList.addParamAttribute(context, 1, ::llvm::Attribute::NoCapture); +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::getWithCaptureInfo(context, llvm::CaptureInfo::none())); +argAttrList = argAttrList.addParamAttribute(context, 1, ::llvm::Attribute::getWithCaptureInfo(context, llvm::CaptureInfo::none())); argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoUndef); argAttrList = argAttrList.addRetAttribute(context, ::llvm::Attribute::NoUndef); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); @@ -202,7 +202,7 @@ m_attributeLists[0] = argAttrList.addFnAttributes(context, attrBuilder); { ::llvm::AttrBuilder attrBuilder{context}; ::llvm::AttributeList argAttrList; -argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::getWithCaptureInfo(context, llvm::CaptureInfo::none())); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); @@ -211,7 +211,7 @@ m_attributeLists[1] = argAttrList.addFnAttributes(context, attrBuilder); { ::llvm::AttrBuilder attrBuilder{context}; ::llvm::AttributeList argAttrList; -argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::NoCapture); +argAttrList = argAttrList.addParamAttribute(context, 0, ::llvm::Attribute::getWithCaptureInfo(context, llvm::CaptureInfo::none())); attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); From 5e44f3214c00ab4f56f253ca7441da4180d74fc0 Mon Sep 17 00:00:00 2001 From: Stefan Stipanovic Date: Tue, 11 Mar 2025 14:54:24 +0100 Subject: [PATCH 4/4] fix test --- test/example/generated/ExampleDialect.h.inc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 91f9b47..cc1874e 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -158,6 +158,9 @@ Extra = 2, /// /// Both arguments get a parameter attribute, as well as return value /// +/// Arguments +/// * Value * lhs +/// * Value * rhs class BufferCompareOp : public ::llvm::CallInst { static const ::llvm::StringLiteral s_name; //{"xd.ir.buffer.compare.op"};