Skip to content

Commit a879be8

Browse files
krzysz00gysit
andauthored
[mlir][LLVMIR] Support memory model relaxation annotations (MMRA) (#157770)
This commit adds support for exporting and importing MMRA data in the LLVM dialect. MMRA is a potentially-discardable piece of metadata that can be placed on any operation that touches memory (fences, loads, stores, atomics, and intrinsics that operate on memory). It includes one (technically zero) or more prefix:suffix string pairs which indicate ways in which the LLVM memory model can be relaxed for these annotations. At the MLIR level, each tag is represented with a `#llvm.mmra_tag<"prefix":"suffix">` attribute, and the MMRA metadata as a whole is represented as a discardable llvm.mmra attribute. (This discardability both allows us to transparently enable MMRA for wrapper dialects like ROCDL and ensures that MLIR passes which don't know about MMRA combining will, conservatively, discard the annotations, per the LLVM spec). --------- Co-authored-by: Tobias Gysi <[email protected]>
1 parent 110ab5a commit a879be8

File tree

7 files changed

+210
-1
lines changed

7 files changed

+210
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,47 @@ def LLVM_TBAATagArrayAttr
12321232
let constBuilderCall = ?;
12331233
}
12341234

1235+
//===----------------------------------------------------------------------===//
1236+
// MMRATagAttr
1237+
//===----------------------------------------------------------------------===//
1238+
1239+
def LLVM_MMRATagAttr : LLVM_Attr<"MMRATag", "mmra_tag"> {
1240+
let parameters = (ins
1241+
StringRefParameter<>:$prefix,
1242+
StringRefParameter<>:$suffix
1243+
);
1244+
1245+
let summary = "MLIR wrapper around a prefix:suffix MMRA tag";
1246+
1247+
let description = [{
1248+
Defines a single memory model relaxation annotation (MMRA) entry
1249+
with prefix `$prefix` and suffix `$suffix`. This corresponds directly
1250+
to an LLVM `!{prefix, suffix}` metadata tuple, which is often written
1251+
`prefix:suffix` as shorthand.
1252+
1253+
Example:
1254+
```mlir
1255+
#mmra_tag = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">
1256+
#mmra_tag1 = #llvm.mmra_tag<"foo":"bar">
1257+
```
1258+
1259+
Either one MMRA tag or an array of them may be added to any LLVM
1260+
operation that operates on memory.
1261+
1262+
```mlir
1263+
%v = llvm.load %ptr {llvm.mmra = #mmra_tag} : !llvm.ptr -> i8
1264+
llvm.store %v, %ptr2 {llvm.mmra = [#mmra_tag, #mmra_tag1]} : i8, !llvm.ptr
1265+
```
1266+
1267+
See the following link for more details:
1268+
https://llvm.org/docs/MemoryModelRelaxationAnnotations.html
1269+
}];
1270+
1271+
let assemblyFormat = "`<` $prefix `` `:` `` $suffix `>`";
1272+
1273+
let genMnemonicAlias = 1;
1274+
}
1275+
12351276
//===----------------------------------------------------------------------===//
12361277
// ConstantRangeAttr
12371278
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def LLVM_Dialect : Dialect {
3636
static StringRef getIdentAttrName() { return "llvm.ident"; }
3737
static StringRef getModuleFlags() { return "llvm.module.flags"; }
3838
static StringRef getCommandlineAttrName() { return "llvm.commandline"; }
39+
static StringRef getMmraAttrName() { return "llvm.mmra"; }
3940

4041
/// Names of llvm parameter attributes.
4142
static StringRef getAlignAttrName() { return "llvm.align"; }

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
1617
#include "mlir/Support/LLVM.h"
@@ -21,6 +22,7 @@
2122
#include "llvm/IR/InlineAsm.h"
2223
#include "llvm/IR/Instructions.h"
2324
#include "llvm/IR/IntrinsicInst.h"
25+
#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
2426

2527
using namespace mlir;
2628
using namespace mlir::LLVM;
@@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
8890
llvm::LLVMContext::MD_alias_scope,
8991
llvm::LLVMContext::MD_dereferenceable,
9092
llvm::LLVMContext::MD_dereferenceable_or_null,
93+
llvm::LLVMContext::MD_mmra,
9194
context.getMDKindID(vecTypeHintMDName),
9295
context.getMDKindID(workGroupSizeHintMDName),
9396
context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -212,6 +215,39 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
212215
return success();
213216
}
214217

218+
/// Convert the given MMRA metadata (either an MMRA tag or an array of them)
219+
/// into corresponding MLIR attributes and set them on the given operation as a
220+
/// discardable `llvm.mmra` attribute.
221+
static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
222+
LLVM::ModuleImport &moduleImport) {
223+
if (!node)
224+
return success();
225+
226+
// We don't use the LLVM wrappers here because we care about the order
227+
// of the metadata for deterministic roundtripping.
228+
MLIRContext *ctx = op->getContext();
229+
auto toAttribute = [&](llvm::MDNode *tag) -> Attribute {
230+
return LLVM::MMRATagAttr::get(
231+
ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(),
232+
cast<llvm::MDString>(tag->getOperand(1))->getString());
233+
};
234+
Attribute mlirMmra;
235+
if (llvm::MMRAMetadata::isTagMD(node)) {
236+
mlirMmra = toAttribute(node);
237+
} else {
238+
SmallVector<Attribute> tags;
239+
for (const llvm::MDOperand &operand : node->operands()) {
240+
auto *tagNode = dyn_cast<llvm::MDNode>(operand.get());
241+
if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode))
242+
return failure();
243+
tags.push_back(toAttribute(tagNode));
244+
}
245+
mlirMmra = ArrayAttr::get(ctx, tags);
246+
}
247+
op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
248+
return success();
249+
}
250+
215251
/// Converts the given loop metadata node to an MLIR loop annotation attribute
216252
/// and attaches it to the imported operation if the translation succeeds.
217253
/// Returns failure otherwise.
@@ -432,7 +468,8 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
432468
return setDereferenceableAttr(
433469
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
434470
moduleImport);
435-
471+
if (kind == llvm::LLVMContext::MD_mmra)
472+
return setMmraAttr(node, op, moduleImport);
436473
llvm::LLVMContext &context = node->getContext();
437474
if (kind == context.getMDKindID(vecTypeHintMDName))
438475
return setVecTypeHintAttr(builder, node, op, moduleImport);

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "llvm/IR/Instructions.h"
2525
#include "llvm/IR/MDBuilder.h"
2626
#include "llvm/IR/MatrixBuilder.h"
27+
#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
28+
#include "llvm/Support/LogicalResult.h"
2729

2830
using namespace mlir;
2931
using namespace mlir::LLVM;
@@ -723,6 +725,40 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
723725
return failure();
724726
}
725727

728+
static LogicalResult
729+
amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
730+
NamedAttribute attribute,
731+
LLVM::ModuleTranslation &moduleTranslation) {
732+
StringRef name = attribute.getName();
733+
if (name == LLVMDialect::getMmraAttrName()) {
734+
SmallVector<llvm::MMRAMetadata::TagT> tags;
735+
if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
736+
tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
737+
} else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
738+
for (Attribute attr : manyTags) {
739+
auto tag = dyn_cast<MMRATagAttr>(attr);
740+
if (!tag)
741+
return op.emitOpError(
742+
"MMRA annotations array contains value that isn't an MMRA tag");
743+
tags.emplace_back(tag.getPrefix(), tag.getSuffix());
744+
}
745+
} else {
746+
return op.emitOpError(
747+
"llvm.mmra is something other than an MMRA tag or an array of them");
748+
}
749+
llvm::MDTuple *mmraMd =
750+
llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags);
751+
if (!mmraMd) {
752+
// Empty list, canonicalizes to nothing
753+
return success();
754+
}
755+
for (llvm::Instruction *inst : instructions)
756+
inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
757+
return success();
758+
}
759+
return success();
760+
}
761+
726762
namespace {
727763
/// Implementation of the dialect interface that converts operations belonging
728764
/// to the LLVM dialect to LLVM IR.
@@ -738,6 +774,14 @@ class LLVMDialectLLVMIRTranslationInterface
738774
LLVM::ModuleTranslation &moduleTranslation) const final {
739775
return convertOperationImpl(*op, builder, moduleTranslation);
740776
}
777+
778+
/// Handle some metadata that is represented as a discardable attribute.
779+
LogicalResult
780+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
781+
NamedAttribute attribute,
782+
LLVM::ModuleTranslation &moduleTranslation) const final {
783+
return amendOperationImpl(*op, instructions, attribute, moduleTranslation);
784+
}
741785
};
742786
} // namespace
743787

mlir/test/Dialect/LLVMIR/mmra.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: mlir-opt %s -split-input-file --verify-roundtrip --mlir-print-local-scope | FileCheck %s
2+
3+
// CHECK-LABEL: llvm.func @native
4+
// CHECK: llvm.load
5+
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
6+
// CHECK: llvm.fence
7+
// CHECK-SAME: llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #llvm.mmra_tag<"foo":"bar">]
8+
// CHECK: llvm.store
9+
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
10+
11+
#mmra_tag = #llvm.mmra_tag<"foo":"bar">
12+
13+
llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
14+
%0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
15+
llvm.fence syncscope("workgroup-one-as") release
16+
{llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
17+
llvm.store %0, %y {llvm.mmra = #llvm.mmra_tag<"foo":"bar">} : i32, !llvm.ptr
18+
llvm.return
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: llvm.func @foreign_op
24+
// CHECK: rocdl.load.to.lds
25+
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"fake":"example">
26+
llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
27+
rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
28+
llvm.return
29+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-DAG: #[[$MMRA0:.+]] = #llvm.mmra_tag<"foo":"bar">
4+
; CHECK-DAG: #[[$MMRA1:.+]] = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">
5+
6+
; CHECK-LABEL: llvm.func @native
7+
define void @native(ptr %x, ptr %y) {
8+
; CHECK: llvm.load
9+
; CHECK-SAME: llvm.mmra = #[[$MMRA0]]
10+
%v = load i32, ptr %x, align 4, !mmra !0
11+
; CHECK: llvm.fence
12+
; CHECK-SAME: llvm.mmra = [#[[$MMRA1]], #[[$MMRA0]]]
13+
fence syncscope("workgroup-one-as") release, !mmra !2
14+
; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}}
15+
store i32 %v, ptr %y, align 4, !mmra !3
16+
ret void
17+
}
18+
19+
!0 = !{!"foo", !"bar"}
20+
!1 = !{!"amdgpu-synchronize-as", !"local"}
21+
!2 = !{!1, !0}
22+
!3 = !{}

mlir/test/Target/LLVMIR/mmra.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @native
4+
// CHECK: load
5+
// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
6+
// CHECK: fence
7+
// CHECK-SAME: !mmra ![[MMRA1:[0-9]+]]
8+
// CHECK: store {{.*}}, align 4{{$}}
9+
10+
#mmra_tag = #llvm.mmra_tag<"foo":"bar">
11+
12+
llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
13+
%0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
14+
llvm.fence syncscope("workgroup-one-as") release
15+
{llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
16+
llvm.store %0, %y {llvm.mmra = []} : i32, !llvm.ptr
17+
llvm.return
18+
}
19+
20+
// Actual MMRA metadata
21+
// CHECK-DAG: ![[MMRA0]] = !{!"foo", !"bar"}
22+
// CHECK-DAG: ![[MMRA_PART0:[0-9]+]] = !{!"amdgpu-synchronize-as", !"local"}
23+
// CHECK-DAG: ![[MMRA1]] = !{![[MMRA_PART0]], ![[MMRA0]]}
24+
25+
// -----
26+
27+
// CHECK-LABEL: define void @foreign_op
28+
// CHECK: call void @llvm.amdgcn.load.to.lds
29+
// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
30+
llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
31+
rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
32+
llvm.return
33+
}
34+
35+
// CHECK: ![[MMRA0]] = !{!"fake", !"example"}

0 commit comments

Comments
 (0)