|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
13 | 13 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
|
| 14 | +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
14 | 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
15 | 16 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
|
16 | 17 | #include "mlir/Support/LLVM.h"
|
|
21 | 22 | #include "llvm/IR/InlineAsm.h"
|
22 | 23 | #include "llvm/IR/Instructions.h"
|
23 | 24 | #include "llvm/IR/IntrinsicInst.h"
|
| 25 | +#include "llvm/IR/MemoryModelRelaxationAnnotations.h" |
24 | 26 |
|
25 | 27 | using namespace mlir;
|
26 | 28 | using namespace mlir::LLVM;
|
@@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
|
88 | 90 | llvm::LLVMContext::MD_alias_scope,
|
89 | 91 | llvm::LLVMContext::MD_dereferenceable,
|
90 | 92 | llvm::LLVMContext::MD_dereferenceable_or_null,
|
| 93 | + llvm::LLVMContext::MD_mmra, |
91 | 94 | context.getMDKindID(vecTypeHintMDName),
|
92 | 95 | context.getMDKindID(workGroupSizeHintMDName),
|
93 | 96 | context.getMDKindID(reqdWorkGroupSizeMDName),
|
@@ -212,6 +215,39 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
|
212 | 215 | return success();
|
213 | 216 | }
|
214 | 217 |
|
| 218 | +/// Convert the given MMRA metadata (either an MMRA tag or an array of them) |
| 219 | +/// into corresponding MLIR attributes and set them on the given operation as a |
| 220 | +/// discardable `llvm.mmra` attribute. |
| 221 | +static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op, |
| 222 | + LLVM::ModuleImport &moduleImport) { |
| 223 | + if (!node) |
| 224 | + return success(); |
| 225 | + |
| 226 | + // We don't use the LLVM wrappers here becasue we care about the order |
| 227 | + // of the metadata for deterministic roundtripping. |
| 228 | + MLIRContext *ctx = op->getContext(); |
| 229 | + auto toAttribute = [&](llvm::MDNode *tag) -> Attribute { |
| 230 | + return LLVM::MMRATagAttr::get( |
| 231 | + ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(), |
| 232 | + cast<llvm::MDString>(tag->getOperand(1))->getString()); |
| 233 | + }; |
| 234 | + Attribute mlirMmra; |
| 235 | + if (llvm::MMRAMetadata::isTagMD(node)) { |
| 236 | + mlirMmra = toAttribute(node); |
| 237 | + } else { |
| 238 | + SmallVector<Attribute> tags; |
| 239 | + for (const llvm::MDOperand &operand : node->operands()) { |
| 240 | + auto *tagNode = dyn_cast<llvm::MDNode>(operand.get()); |
| 241 | + if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode)) |
| 242 | + return failure(); |
| 243 | + tags.push_back(toAttribute(tagNode)); |
| 244 | + } |
| 245 | + mlirMmra = ArrayAttr::get(ctx, tags); |
| 246 | + } |
| 247 | + op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra); |
| 248 | + return success(); |
| 249 | +} |
| 250 | + |
215 | 251 | /// Converts the given loop metadata node to an MLIR loop annotation attribute
|
216 | 252 | /// and attaches it to the imported operation if the translation succeeds.
|
217 | 253 | /// Returns failure otherwise.
|
@@ -432,7 +468,8 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
|
432 | 468 | return setDereferenceableAttr(
|
433 | 469 | node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
|
434 | 470 | moduleImport);
|
435 |
| - |
| 471 | + if (kind == llvm::LLVMContext::MD_mmra) |
| 472 | + return setMmraAttr(node, op, moduleImport); |
436 | 473 | llvm::LLVMContext &context = node->getContext();
|
437 | 474 | if (kind == context.getMDKindID(vecTypeHintMDName))
|
438 | 475 | return setVecTypeHintAttr(builder, node, op, moduleImport);
|
|
0 commit comments