diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 358767d..8f6d172 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -361,3 +361,41 @@ def BufferCompareOp : Op { + let summary = "a custom struct-backed type"; + let description = [{ + Test that a struct-backed type works correctly. + }]; + let typeArguments = (args AttrI32:$field0, AttrI32:$field1, AttrI32:$field2); + let representation = (repr_struct (IntegerType 41)); + + let defaultGetterHasExplicitContextArgument = 1; +} + +def DummyStructBackedInpOp : Op { + let summary = "a custom op using input arg with struct-backed type"; + let description = [{ + Test that an operation that takes argument with struct-backed type works correctly. + }]; + let arguments = (ins + StructBackedType:$inp + ); + let results = (outs + I32:$ret + ); +} + +def DummyStructBackedOutpOp : Op { + let summary = "a custom op returning value with struct-backed type"; + let description = [{ + Test that an operation that returns value with struct-backed type works correctly. + }]; + let arguments = (ins + I32:$extra_arg + ); + let results = (outs + StructBackedType:$outp + ); + let defaultBuilderHasExplicitResultType = true; +} diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index 646532d..b78c601 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -149,6 +149,10 @@ void createFunctionExample(Module &module, const Twine &name) { b.create("Hello world!"); + xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, 2); + auto *structBackedVal = b.create(structBackedTy, b.getInt32(42), "gen.struct.backed.val"); + b.create(structBackedVal, "consume.struct.backed.val"); + b.CreateRetVoid(); } diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index 9acae9d..8d1d623 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -215,6 +215,9 @@ def F64 : SpecialBuiltinType<"Double">; def F32 : SpecialBuiltinType<"Float">; def F16 : SpecialBuiltinType<"Half">; +def repr_targetext; +def repr_struct; + /// All types that are defined by a dialect are derived from this class. class DialectType : Type, Predicate { dag typeArguments = ?; @@ -229,6 +232,13 @@ class DialectType : Type, Predicate { string summary = ?; string description = ?; + + /// How the dialect type is represented in LLVM IR: + /// - (repr_targetext): use a TargetExtType (the default) + /// - (repr_struct ): use a StructType with the given type as the + /// discriminant; the discriminant should be a type that cannot naturally + /// appear elsewhere, e.g. (repr_struct (IntegerType 41)) + dag representation = (repr_targetext); } def and; diff --git a/include/llvm-dialects/TableGen/DialectType.h b/include/llvm-dialects/TableGen/DialectType.h index 5dd1a18..8625fab 100644 --- a/include/llvm-dialects/TableGen/DialectType.h +++ b/include/llvm-dialects/TableGen/DialectType.h @@ -75,6 +75,9 @@ class DialectType : public BaseCppPredicate { std::string m_context; std::vector m_getterArguments; unsigned m_argBegin = 0; + + bool m_structBacked = false; + unsigned m_structSentinelBitWidth; }; } // namespace llvm_dialects diff --git a/lib/TableGen/DialectType.cpp b/lib/TableGen/DialectType.cpp index 183c52c..17c2916 100644 --- a/lib/TableGen/DialectType.cpp +++ b/lib/TableGen/DialectType.cpp @@ -44,6 +44,23 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context, m_summary = record->getValueAsString("summary"); m_description = record->getValueAsString("description"); + if (auto *dag = + cast(record->getValue("representation")->getValue())) { + if (cast(dag->getOperator())->getDef()->getName() == + "repr_struct") { + m_structBacked = true; + + if (dag->getNumArgs() != 1) { + errs << "'repr_struct' expects exactly one type argument\n"; + return false; + } + m_structSentinelBitWidth = + llvm::cast( + llvm::cast(dag->getArg(0))->getArg(0)) + ->getValue(); + } + } + for (unsigned argIdx = 0; argIdx < m_arguments.size(); ++argIdx) m_canDerive.push_back(true); m_canCheckFromSelf = true; @@ -72,8 +89,11 @@ bool DialectType::init(raw_ostream &errs, GenDialectsContext &context, evaluate << ')'; } - - m_check = tgfmt("::llvm::isa<$_type>($$self)", &fmt); + if (m_structBacked) { + m_check = tgfmt("::llvm::isa<::llvm::StructType>($$self)", &fmt); + } else { + m_check = tgfmt("::llvm::isa<$_type>($$self)", &fmt); + } if (!m_arguments[0].type || !m_arguments[0].type->isTypeArg() || m_arguments[0].constraint) { @@ -150,7 +170,56 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { fmt.addSubst("_type", getName()); fmt.addSubst("mnemonic", getMnemonic()); - out << tgfmt(R"( + if (m_structBacked) { + out << tgfmt(R"( + class $_type : public ::llvm::StructType { + using ::llvm::StructType::StructType; + public: + static constexpr ::llvm::StringLiteral s_prefix{"$dialect.$mnemonic."}; + + using ::llvm::StructType::getElementType; + + static $_type *get( + )", + &fmt); + + bool contextPresent = + !m_getterArguments.empty() && m_getterArguments.front().cppType.find( + "LLVMContext") != std::string::npos; + if (!contextPresent) { + out << "::llvm::LLVMContext &" << m_context; + if (!m_getterArguments.empty()) + out << ", "; + } + for (const auto &argument : llvm::enumerate(m_getterArguments)) { + if (argument.index() != 0) + out << ", "; + out << argument.value().cppType << ' ' << argument.value().name; + } + out << ");\n\n"; + + out << " static bool classof(const ::llvm::Type *t);\n\n"; + + unsigned fieldIdx = 1; // sentinel + for (const auto &argument : typeArguments()) { + std::string camel = convertToCamelFromSnakeCase(argument.name, true); + out << tgfmt( + R"( unsigned get$0() const { + ::llvm::Type *elt = getElementType($1); + if (elt->isStructTy()) + return 0; + return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + } +)", + &fmt, camel, fieldIdx++); + } + + out << " };\n\n"; + } else { + + // TargetExtType + + out << tgfmt(R"( class $_type : public ::llvm::TargetExtType { static constexpr ::llvm::StringLiteral s_name{"$dialect.$mnemonic"}; @@ -166,22 +235,23 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { bool verifier(::llvm::raw_ostream &errs) const; )", - &fmt); + &fmt); - out << tgfmt("static $_type *get(", &fmt); - for (const auto &argument : llvm::enumerate(m_getterArguments)) { - if (argument.index() != 0) - out << ", "; - out << argument.value().cppType << ' ' << argument.value().name; - } - out << ");\n\n"; + out << tgfmt("static $_type *get(", &fmt); + for (const auto &argument : llvm::enumerate(m_getterArguments)) { + if (argument.index() != 0) + out << ", "; + out << argument.value().cppType << ' ' << argument.value().name; + } + out << ");\n\n"; - for (const auto &argument : typeArguments()) { - out << tgfmt("$0 get$1() const;\n", &fmt, argument.type->getCppType(), - convertToCamelFromSnakeCase(argument.name, true)); - } + for (const auto &argument : typeArguments()) { + out << tgfmt("$0 get$1() const;\n", &fmt, argument.type->getCppType(), + convertToCamelFromSnakeCase(argument.name, true)); + } - out << "};\n\n"; + out << "};\n\n"; + } } void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { @@ -194,77 +264,147 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { fmt.addSubst("types", symbols.chooseName("types")); fmt.addSubst("ints", symbols.chooseName("ints")); fmt.addSubst("_errs", symbols.chooseName("errs")); + fmt.addSubst("os", symbols.chooseName("os")); + fmt.addSubst("name", symbols.chooseName("name")); + fmt.addSubst("fields", symbols.chooseName("fields")); + fmt.addSubst("st", symbols.chooseName("st")); + + if (m_structBacked) { + out << tgfmt("$_type* $_type::get(", &fmt); + bool contextPresent = + !m_getterArguments.empty() && m_getterArguments.front().cppType.find( + "LLVMContext") != std::string::npos; + if (!contextPresent) { + out << "::llvm::LLVMContext &" << m_context; + if (!m_getterArguments.empty()) + out << ", "; + } + for (auto argument : llvm::enumerate(m_getterArguments)) { + if (argument.index() != 0) + out << ", "; + out << argument.value().cppType << ' ' << argument.value().name; + } + out << ") {\n"; - // Output the type argument getters. - unsigned typeIdx = 0; - unsigned intIdx = 0; - for (const auto &argument : typeArguments()) { - std::string expr; - if (argument.type->isTypeArg()) { - expr = tgfmt("type_params()[$0]", &fmt, typeIdx); - ++typeIdx; - } else { - expr = tgfmt("int_params()[$0]", &fmt, intIdx); - expr = tgfmt(cast(argument.type)->getFromUnsigned(), &fmt, expr); - ++intIdx; + auto getterArgs = + ArrayRef(m_getterArguments).drop_front(m_argBegin); + + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + if (auto *attr = dyn_cast(argument.type)) { + out << tgfmt(attr->getCheck(), &fmt, getterArg.name) << '\n'; + } } - FmtContextScope scope{fmt}; - fmt.addSubst("type", argument.type->getCppType()); - fmt.addSubst("name", convertToCamelFromSnakeCase(argument.name, true)); - fmt.addSubst("expr", expr); + out << tgfmt( + " std::string $name; ::llvm::raw_string_ostream $os($name);\n", &fmt); + out << tgfmt(" $os << \"$0\";\n", &fmt, m_mnemonic); + for (const auto &getterArg : getterArgs) + out << tgfmt(" $os << '.' << (uint64_t)$0;\n", &fmt, getterArg.name); + + out << tgfmt(" ::std::vector<::llvm::Type*> $fields;\n", &fmt); + out << tgfmt( + " $fields.push_back(::llvm::IntegerType::get($_context, $0));\n", &fmt, + Twine(m_structSentinelBitWidth)); + + for (const auto &getterArg : getterArgs) { + out << tgfmt(R"( + if ($0 == 0) + $fields.push_back(::llvm::StructType::get($_context)); + else + $fields.push_back(::llvm::IntegerType::get($_context, $0)); +)", + &fmt, getterArg.name); + } + out << tgfmt(" auto *$st = ::llvm::StructType::create($_context, " + "$fields, $os.str(), /*isPacked=*/false);\n", + &fmt); + out << tgfmt(" return static_cast<$_type *>($st);\n}\n\n", &fmt); out << tgfmt(R"( +bool $_type::classof(const ::llvm::Type *t) { + auto *st = ::llvm::dyn_cast<::llvm::StructType>(t); + if (!st) + return false; + return st->getNumElements() && + st->getElementType(0)->isIntegerTy($0); +} +)", + &fmt, Twine(m_structSentinelBitWidth)); + } else { + // TargetExtType + + // Output the type argument getters. + unsigned typeIdx = 0; + unsigned intIdx = 0; + for (const auto &argument : typeArguments()) { + std::string expr; + if (argument.type->isTypeArg()) { + expr = tgfmt("type_params()[$0]", &fmt, typeIdx); + ++typeIdx; + } else { + expr = tgfmt("int_params()[$0]", &fmt, intIdx); + expr = tgfmt(cast(argument.type)->getFromUnsigned(), &fmt, expr); + ++intIdx; + } + + FmtContextScope scope{fmt}; + fmt.addSubst("type", argument.type->getCppType()); + fmt.addSubst("name", convertToCamelFromSnakeCase(argument.name, true)); + fmt.addSubst("expr", expr); + + out << tgfmt(R"( $type $_type::get$name() const { return $expr; } )", - &fmt, expr); - } + &fmt, expr); + } - // Output the default getter. - out << tgfmt("$_type* $_type::get(", &fmt); - for (auto argument : llvm::enumerate(m_getterArguments)) { - if (argument.index() != 0) - out << ", "; - out << argument.value().cppType << ' ' << argument.value().name; - } - out << ") {\n"; + // Output the default getter. + out << tgfmt("$_type* $_type::get(", &fmt); + for (auto argument : llvm::enumerate(m_getterArguments)) { + if (argument.index() != 0) + out << ", "; + out << argument.value().cppType << ' ' << argument.value().name; + } + out << ") {\n"; - out << m_prelude; + out << m_prelude; - auto getterArgs = - ArrayRef(m_getterArguments).drop_front(m_argBegin); + auto getterArgs = + ArrayRef(m_getterArguments).drop_front(m_argBegin); - for (const auto &[argument, getterArg] : - llvm::zip(typeArguments(), getterArgs)) { - if (auto *attr = dyn_cast(argument.type)) { - out << tgfmt(attr->getCheck(), &fmt, getterArg.name) << '\n'; + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + if (auto *attr = dyn_cast(argument.type)) { + out << tgfmt(attr->getCheck(), &fmt, getterArg.name) << '\n'; + } } - } - out << tgfmt("::std::array<::llvm::Type *, $0> $types = {\n", &fmt, typeIdx); - for (const auto &[argument, getterArg] : - llvm::zip(typeArguments(), getterArgs)) { - if (argument.type->isTypeArg()) - out << getterArg.name << ",\n"; - } - out << tgfmt(R"( + out << tgfmt("::std::array<::llvm::Type *, $0> $types = {\n", &fmt, + typeIdx); + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + if (argument.type->isTypeArg()) + out << getterArg.name << ",\n"; + } + out << tgfmt(R"( }; ::std::array $ints = { )", - &fmt, intIdx); - for (const auto &[argument, getterArg] : - llvm::zip(typeArguments(), getterArgs)) { - if (!argument.type->isTypeArg()) { - std::string expr = tgfmt(cast(argument.type)->getToUnsigned(), &fmt, - getterArg.name); - out << expr << ",\n"; + &fmt, intIdx); + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + if (!argument.type->isTypeArg()) { + std::string expr = tgfmt(cast(argument.type)->getToUnsigned(), + &fmt, getterArg.name); + out << expr << ",\n"; + } } - } - out << tgfmt(R"( + out << tgfmt(R"( }; auto *$type = ::llvm::cast<$_type>(::llvm::TargetExtType::get($_context, s_name, $types, $ints)); @@ -274,10 +414,10 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { return $type; } )", - &fmt); + &fmt); - // Output the verifier. - out << tgfmt(R"( + // Output the verifier. + out << tgfmt(R"( bool $_type::verifier(::llvm::raw_ostream &$_errs) const { ::llvm::LLVMContext &$_context = getContext(); (void)$_context; @@ -298,23 +438,24 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { return false; } )", - &fmt, typeIdx, intIdx); + &fmt, typeIdx, intIdx); - Assignment assignment; - Evaluator eval(symbols, assignment, m_system, out, fmt); + Assignment assignment; + Evaluator eval(symbols, assignment, m_system, out, fmt); - for (const auto &[argument, getterArg] : - llvm::zip(typeArguments(), getterArgs)) { - FmtContextScope scope{fmt}; - fmt.addSubst("getter", convertToCamelFromSnakeCase(argument.name, true)); - fmt.addSubst("name", getterArg.name); - out << tgfmt("auto $name = get$getter();\n(void)$name;\n", &fmt); + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + FmtContextScope scope{fmt}; + fmt.addSubst("getter", convertToCamelFromSnakeCase(argument.name, true)); + fmt.addSubst("name", getterArg.name); + out << tgfmt("auto $name = get$getter();\n(void)$name;\n", &fmt); - auto variable = m_scope.findVariable(argument.name); - assignment.assign(variable, fmt.getSubstFor("name").value()); - } + auto variable = m_scope.findVariable(argument.name); + assignment.assign(variable, fmt.getSubstFor("name").value()); + } - eval.check(true); + eval.check(true); - out << "return true;\n}\n\n"; + out << "return true;\n}\n\n"; + } } diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 6f817b8..2a39b5b 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -67,6 +67,16 @@ namespace xd::cpp { state.setError(); }); + builder.add([](::llvm_dialects::VerifierState &state, DummyStructBackedInpOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + + builder.add([](::llvm_dialects::VerifierState &state, DummyStructBackedOutpOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + builder.add([](::llvm_dialects::VerifierState &state, ExtractElementOp &op) { if (!op.verifier(state.out())) state.setError(); @@ -248,6 +258,44 @@ m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder); } } +StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2) { + + + + std::string name; ::llvm::raw_string_ostream os(name); + os << "struct.backed"; + os << '.' << (uint64_t)field0; + os << '.' << (uint64_t)field1; + os << '.' << (uint64_t)field2; + ::std::vector<::llvm::Type*> fields; + fields.push_back(::llvm::IntegerType::get(ctx, 41)); + + if (field0 == 0) + fields.push_back(::llvm::StructType::get(ctx)); + else + fields.push_back(::llvm::IntegerType::get(ctx, field0)); + + if (field1 == 0) + fields.push_back(::llvm::StructType::get(ctx)); + else + fields.push_back(::llvm::IntegerType::get(ctx, field1)); + + if (field2 == 0) + fields.push_back(::llvm::StructType::get(ctx)); + else + fields.push_back(::llvm::IntegerType::get(ctx, field2)); + auto *st = ::llvm::StructType::create(ctx, fields, os.str(), /*isPacked=*/false); + return static_cast(st); +} + + +bool StructBackedType::classof(const ::llvm::Type *t) { + auto *st = ::llvm::dyn_cast<::llvm::StructType>(t); + if (!st) + return false; + return st->getNumElements() && + st->getElementType(0)->isIntegerTy(41); +} XdHandleType* XdHandleType::get(::llvm::LLVMContext & ctx) { ::std::array<::llvm::Type *, 0> types = { @@ -704,6 +752,177 @@ rhs + const ::llvm::StringLiteral DummyStructBackedInpOp::s_name{"xd.ir.struct.backed.inp.op"}; + + DummyStructBackedInpOp* DummyStructBackedInpOp::create(llvm_dialects::Builder& b, ::llvm::Value * inp, const llvm::Twine &instName) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(6); + auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), true); + + auto fn = module.getOrInsertFunction(s_name, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << s_name << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + ::llvm::SmallVector<::llvm::Value*, 1> args = { +inp + }; + + return ::llvm::cast(b.CreateCall(fn, args, instName)); + } + + + bool DummyStructBackedInpOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 1) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 1\n"; + return false; + } + ::llvm::Type * const inpType = getInp()->getType(); +(void)inpType; +::llvm::Type * const retType = getRet()->getType(); +(void)retType; + + if (::llvm::IntegerType::get(context, 32) != retType) { + errs << " unexpected value of $ret:\n"; + errs << " expected: " << printable(::llvm::IntegerType::get(context, 32)) << '\n'; + errs << " actual: " << printable(retType) << '\n'; + + return false; + } + + if (!(::llvm::isa<::llvm::StructType>(inpType))) { + errs << " failed check for StructBackedType:$inp\n"; + + errs << " with $inp = " << printable(inpType) << '\n'; + + + return false; + } + return true; +} + + + ::llvm::Value * DummyStructBackedInpOp::getInp() const { + return getArgOperand(ArgumentIndex::Inp); + } + + void DummyStructBackedInpOp::setInp(::llvm::Value * inp) { + setArgOperand(ArgumentIndex::Inp, inp); + } +::llvm::Value *DummyStructBackedInpOp::getRet() {return this;} + + + + const ::llvm::StringLiteral DummyStructBackedOutpOp::s_name{"xd.ir.struct.backed.outp.op"}; + + DummyStructBackedOutpOp* DummyStructBackedOutpOp::create(llvm_dialects::Builder& b, ::llvm::Type* outpType, ::llvm::Value * extraArg, const llvm::Twine &instName) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(6); + + std::string mangledName = + ::llvm_dialects::getMangledName(s_name, {outpType}); + auto fnType = ::llvm::FunctionType::get(outpType, { +extraArg->getType(), +}, false); + + auto fn = module.getOrInsertFunction(mangledName, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << mangledName << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + ::llvm::SmallVector<::llvm::Value*, 1> args = { +extraArg + }; + + return ::llvm::cast(b.CreateCall(fn, args, instName)); + } + + + bool DummyStructBackedOutpOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 1) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 1\n"; + return false; + } + ::llvm::Type * const extraArgType = getExtraArg()->getType(); +(void)extraArgType; +::llvm::Type * const outpType = getOutp()->getType(); +(void)outpType; + + if (::llvm::IntegerType::get(context, 32) != extraArgType) { + errs << " unexpected value of $extra_arg:\n"; + errs << " expected: " << printable(::llvm::IntegerType::get(context, 32)) << '\n'; + errs << " actual: " << printable(extraArgType) << '\n'; + + return false; + } + + if (!(::llvm::isa<::llvm::StructType>(outpType))) { + errs << " failed check for StructBackedType:$outp\n"; + + errs << " with $outp = " << printable(outpType) << '\n'; + + + return false; + } + return true; +} + + + ::llvm::Value * DummyStructBackedOutpOp::getExtraArg() const { + return getArgOperand(ArgumentIndex::ExtraArg); + } + + void DummyStructBackedOutpOp::setExtraArg(::llvm::Value * extra_arg) { + setArgOperand(ArgumentIndex::ExtraArg, extra_arg); + } +::llvm::Value *DummyStructBackedOutpOp::getOutp() {return this;} + + + const ::llvm::StringLiteral ExtractElementOp::s_name{"xd.ir.extractelement"}; ExtractElementOp* ExtractElementOp::create(llvm_dialects::Builder& b, ::llvm::Value * vector, ::llvm::Value * index, const llvm::Twine &instName) { @@ -2531,6 +2750,22 @@ data } + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{false, "xd.ir.struct.backed.inp.op"}; + return desc; + } + + + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{true, "xd.ir.struct.backed.outp.op"}; + return desc; + } + + template <> const ::llvm_dialects::OpDescription & ::llvm_dialects::OpDescription::get() { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index cc1874e..46189b9 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -50,6 +50,39 @@ namespace xd::cpp { ::std::array<::llvm::AttributeList, 7> m_attributeLists; }; + class StructBackedType : public ::llvm::StructType { + using ::llvm::StructType::StructType; + public: + static constexpr ::llvm::StringLiteral s_prefix{"xd.ir.struct.backed."}; + + using ::llvm::StructType::getElementType; + + static StructBackedType *get( + ::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2); + + static bool classof(const ::llvm::Type *t); + + unsigned getField0() const { + ::llvm::Type *elt = getElementType(1); + if (elt->isStructTy()) + return 0; + return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + } + unsigned getField1() const { + ::llvm::Type *elt = getElementType(2); + if (elt->isStructTy()) + return 0; + return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + } + unsigned getField2() const { + ::llvm::Type *elt = getElementType(3); + if (elt->isStructTy()) + return 0; + return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + } + }; + + class XdHandleType : public ::llvm::TargetExtType { static constexpr ::llvm::StringLiteral s_name{"xd.ir.handle"}; @@ -225,6 +258,72 @@ Rhs = 1, ::llvm::Value * getResult(); + }; + + /// DummyStructBackedInpOp +/// a custom op using input arg with struct-backed type +/// +/// Test that an operation that takes argument with struct-backed type works correctly. +/// +/// Arguments +/// * Value * inp + + class DummyStructBackedInpOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.ir.struct.backed.inp.op"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isSimpleOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static DummyStructBackedInpOp* create(::llvm_dialects::Builder& b, ::llvm::Value * inp, const llvm::Twine &instName = ""); + +bool verifier(::llvm::raw_ostream &errs); + +::llvm::Value * getInp() const; + void setInp(::llvm::Value * inp); + struct ArgumentIndex { enum Enum : uint32_t { +Inp = 0, +};}; +::llvm::Value * getRet(); + + + }; + + /// DummyStructBackedOutpOp +/// a custom op returning value with struct-backed type +/// +/// Test that an operation that returns value with struct-backed type works correctly. +/// +/// Arguments +/// * Value * extra_arg + + class DummyStructBackedOutpOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.ir.struct.backed.outp.op"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isOverloadedOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static DummyStructBackedOutpOp* create(::llvm_dialects::Builder& b, ::llvm::Type* outpType, ::llvm::Value * extraArg, const llvm::Twine &instName = ""); + +bool verifier(::llvm::raw_ostream &errs); + +::llvm::Value * getExtraArg() const; + void setExtraArg(::llvm::Value * extra_arg); + struct ArgumentIndex { enum Enum : uint32_t { +ExtraArg = 0, +};}; +::llvm::Value * getOutp(); + + }; /// ExtractElementOp diff --git a/test/example/test-builder.test b/test/example/test-builder.test index 941ed8f..5f88cc3 100644 --- a/test/example/test-builder.test +++ b/test/example/test-builder.test @@ -24,8 +24,8 @@ ;. ; CHECK: @str = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1 ;. -; CHECK-LABEL: @example( -; CHECK-NEXT: entry: +; CHECK-LABEL: define void @example() { +; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.ir.read__i32() ; CHECK-NEXT: [[TMP1:%.*]] = call i64 (...) @xd.ir.sizeof(double poison) ; CHECK-NEXT: [[TMP2:%.*]] = call i32 (...) @xd.ir.itrunc__i32(i64 [[TMP1]]) @@ -66,6 +66,8 @@ ; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.ir.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3) ; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.ir.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4) ; CHECK-NEXT: call void @xd.ir.string.attr.op(ptr @str) +; CHECK-NEXT: [[GEN_STRUCT_BACKED_VAL:%.*]] = call [[SB_1_0_2_:%.*]] @[[XD_IR_STRUCT_BACKED_OUTP_OP__S_SB_1_0_2_S:[a-zA-Z0-9_$\"\\.-]*[a-zA-Z_$\"\\.-][a-zA-Z0-9_$\"\\.-]*]](i32 42) +; CHECK-NEXT: [[CONSUME_STRUCT_BACKED_VAL:%.*]] = call i32 (...) @xd.ir.struct.backed.inp.op([[SB_1_0_2_]] [[GEN_STRUCT_BACKED_VAL]]) ; CHECK-NEXT: ret void ; ;.