Skip to content

Commit d15998f

Browse files
authored
[mlir][ptr] Add translations to LLVMIR for ptr ops. (llvm#156355)
Implements translation from the ptr dialect to LLVM IR for the core pointer operations:

- `ptr.ptr_add` -> `getelementptr`
- `ptr.load` -> `load` with atomic ordering, volatility, and metadata support
- `ptr.store` -> `store` with atomic ordering, volatility, and metadata support
- `ptr.type_offset` -> GEP-based size computation

Example:

```mlir
llvm.func @test(%arg0: !ptr.ptr<#llvm.address_space<0>>) {
  %0 = ptr.type_offset f64 : i32
  %1 = ptr.ptr_add inbounds %arg0, %0 : !ptr.ptr<#llvm.address_space<0>>, i32
  %2 = ptr.load volatile %1 : !ptr.ptr<#llvm.address_space<0>> -> f64
  ptr.store %2, %arg0 : f64, !ptr.ptr<#llvm.address_space<0>>
  llvm.return
}
```

Translates to:

```llvm
define void @test(ptr %0) {
  %2 = getelementptr inbounds i8, ptr %0, i32 8
  %3 = load volatile double, ptr %2, align 8
  store double %3, ptr %0, align 8
  ret void
}
```
1 parent 86879d4 commit d15998f

File tree

3 files changed

+279
-10
lines changed

3 files changed

+279
-10
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def LLVM_NonLoadableTargetExtType : Type<
9595
// type that has size (not void, function, opaque struct type or target
9696
// extension type which does not support memory operations).
9797
def LLVM_LoadableType : Type<
98-
Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>,
99-
Neg<LLVM_NonLoadableTargetExtType.predicate>]>,
98+
Or<[CPred<"mlir::LLVM::isLoadableType($_self)">,
10099
LLVM_PointerElementTypeInterface.predicate]>,
101100
"LLVM type with size">;
102101

mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp

Lines changed: 203 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,194 @@
1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
19+
#include "llvm/ADT/TypeSwitch.h"
20+
#include "llvm/IR/IRBuilder.h"
21+
#include "llvm/IR/Instructions.h"
22+
#include "llvm/IR/Type.h"
23+
#include "llvm/IR/Value.h"
1924

2025
using namespace mlir;
2126
using namespace mlir::ptr;
2227

2328
namespace {
29+
30+
/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31+
static llvm::AtomicOrdering
32+
convertAtomicOrdering(ptr::AtomicOrdering ordering) {
33+
switch (ordering) {
34+
case ptr::AtomicOrdering::not_atomic:
35+
return llvm::AtomicOrdering::NotAtomic;
36+
case ptr::AtomicOrdering::unordered:
37+
return llvm::AtomicOrdering::Unordered;
38+
case ptr::AtomicOrdering::monotonic:
39+
return llvm::AtomicOrdering::Monotonic;
40+
case ptr::AtomicOrdering::acquire:
41+
return llvm::AtomicOrdering::Acquire;
42+
case ptr::AtomicOrdering::release:
43+
return llvm::AtomicOrdering::Release;
44+
case ptr::AtomicOrdering::acq_rel:
45+
return llvm::AtomicOrdering::AcquireRelease;
46+
case ptr::AtomicOrdering::seq_cst:
47+
return llvm::AtomicOrdering::SequentiallyConsistent;
48+
}
49+
llvm_unreachable("Unknown atomic ordering");
50+
}
51+
52+
/// Convert ptr.ptr_add operation
53+
static LogicalResult
54+
convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55+
LLVM::ModuleTranslation &moduleTranslation) {
56+
llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57+
llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
58+
59+
if (!basePtr || !offset)
60+
return ptrAddOp.emitError("Failed to lookup operands");
61+
62+
// Create the GEP flags
63+
llvm::GEPNoWrapFlags gepFlags;
64+
switch (ptrAddOp.getFlags()) {
65+
case ptr::PtrAddFlags::none:
66+
break;
67+
case ptr::PtrAddFlags::nusw:
68+
gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
69+
break;
70+
case ptr::PtrAddFlags::nuw:
71+
gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
72+
break;
73+
case ptr::PtrAddFlags::inbounds:
74+
gepFlags = llvm::GEPNoWrapFlags::inBounds();
75+
break;
76+
}
77+
78+
// Create GEP instruction for pointer arithmetic
79+
llvm::Value *gep =
80+
builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}, "", gepFlags);
81+
82+
moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
83+
return success();
84+
}
85+
86+
/// Convert ptr.load operation
87+
static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
88+
LLVM::ModuleTranslation &moduleTranslation) {
89+
llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
90+
if (!ptr)
91+
return loadOp.emitError("Failed to lookup pointer operand");
92+
93+
// Convert result type to LLVM type
94+
llvm::Type *resultType =
95+
moduleTranslation.convertType(loadOp.getValue().getType());
96+
if (!resultType)
97+
return loadOp.emitError("Failed to convert result type");
98+
99+
// Create the load instruction.
100+
llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
101+
llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
102+
resultType, ptr, alignment, loadOp.getVolatile_());
103+
104+
// Set op flags and metadata.
105+
loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
106+
// Set sync scope if specified
107+
if (loadOp.getSyncscope().has_value()) {
108+
llvm::LLVMContext &ctx = builder.getContext();
109+
llvm::SyncScope::ID syncScope =
110+
ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
111+
loadInst->setSyncScopeID(syncScope);
112+
}
113+
114+
// Set metadata for nontemporal, invariant, and invariant_group
115+
if (loadOp.getNontemporal()) {
116+
llvm::MDNode *nontemporalMD =
117+
llvm::MDNode::get(builder.getContext(),
118+
llvm::ConstantAsMetadata::get(builder.getInt32(1)));
119+
loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
120+
}
121+
122+
if (loadOp.getInvariant()) {
123+
llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
124+
loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
125+
}
126+
127+
if (loadOp.getInvariantGroup()) {
128+
llvm::MDNode *invariantGroupMD =
129+
llvm::MDNode::get(builder.getContext(), {});
130+
loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
131+
invariantGroupMD);
132+
}
133+
134+
moduleTranslation.mapValue(loadOp.getResult(), loadInst);
135+
return success();
136+
}
137+
138+
/// Convert ptr.store operation
139+
static LogicalResult
140+
convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
141+
LLVM::ModuleTranslation &moduleTranslation) {
142+
llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
143+
llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
144+
145+
if (!value || !ptr)
146+
return storeOp.emitError("Failed to lookup operands");
147+
148+
// Create the store instruction.
149+
llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
150+
llvm::StoreInst *storeInst =
151+
builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
152+
153+
// Set op flags and metadata.
154+
storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
155+
// Set sync scope if specified
156+
if (storeOp.getSyncscope().has_value()) {
157+
llvm::LLVMContext &ctx = builder.getContext();
158+
llvm::SyncScope::ID syncScope =
159+
ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
160+
storeInst->setSyncScopeID(syncScope);
161+
}
162+
163+
// Set metadata for nontemporal and invariant_group
164+
if (storeOp.getNontemporal()) {
165+
llvm::MDNode *nontemporalMD =
166+
llvm::MDNode::get(builder.getContext(),
167+
llvm::ConstantAsMetadata::get(builder.getInt32(1)));
168+
storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
169+
}
170+
171+
if (storeOp.getInvariantGroup()) {
172+
llvm::MDNode *invariantGroupMD =
173+
llvm::MDNode::get(builder.getContext(), {});
174+
storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
175+
invariantGroupMD);
176+
}
177+
178+
return success();
179+
}
180+
181+
/// Convert ptr.type_offset operation
182+
static LogicalResult
183+
convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
184+
LLVM::ModuleTranslation &moduleTranslation) {
185+
// Convert the element type to LLVM type
186+
llvm::Type *elementType =
187+
moduleTranslation.convertType(typeOffsetOp.getElementType());
188+
if (!elementType)
189+
return typeOffsetOp.emitError("Failed to convert the element type");
190+
191+
// Convert result type
192+
llvm::Type *resultType =
193+
moduleTranslation.convertType(typeOffsetOp.getResult().getType());
194+
if (!resultType)
195+
return typeOffsetOp.emitError("Failed to convert the result type");
196+
197+
// Use GEP with null pointer to compute type size/offset.
198+
llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
199+
llvm::Value *offsetPtr =
200+
builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
201+
llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
202+
203+
moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
204+
return success();
205+
}
206+
24207
/// Implementation of the dialect interface that converts operations belonging
25208
/// to the `ptr` dialect to LLVM IR.
26209
class PtrDialectLLVMIRTranslationInterface
@@ -33,21 +216,33 @@ class PtrDialectLLVMIRTranslationInterface
33216
/// Translates a single `ptr` dialect operation into LLVM IR by dispatching
/// on its concrete op type to the matching converter (`ptr_add`, `load`,
/// `store`, `type_offset`). Any other operation produces an error that names
/// the operation.
LogicalResult
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) const final {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](PtrAddOp ptrAddOp) {
        return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
      })
      .Case([&](LoadOp loadOp) {
        return convertLoadOp(loadOp, builder, moduleTranslation);
      })
      .Case([&](StoreOp storeOp) {
        return convertStoreOp(storeOp, builder, moduleTranslation);
      })
      .Case([&](TypeOffsetOp typeOffsetOp) {
        return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
      })
      // Fallback: report the unsupported operation by name.
      .Default([&](Operation *op) {
        return op->emitError("Translation for operation '")
               << op->getName() << "' is not implemented.";
      });
}
41238

42239
/// Attaches module-level metadata for functions marked as kernels.
43240
LogicalResult
44241
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
45242
NamedAttribute attribute,
46243
LLVM::ModuleTranslation &moduleTranslation) const final {
47-
// Translation for ptr dialect operations to LLVM IR is currently
48-
// unimplemented.
49-
return op->emitError("Translation for ptr dialect operations to LLVM IR is "
50-
"not implemented.");
244+
// No special amendments needed for ptr dialect operations
245+
return success();
51246
}
52247
};
53248
} // namespace

mlir/test/Target/LLVMIR/ptr.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,78 @@ llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) {
1414
llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
1515
llvm.return
1616
}
17+
18+
// CHECK-LABEL: define ptr @ptr_add
19+
// CHECK-SAME: (ptr %[[PTR:.*]], i32 %[[OFF:.*]]) {
20+
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
21+
// CHECK-NEXT: %[[RES0:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
22+
// CHECK-NEXT: %[[RES1:.*]] = getelementptr nusw i8, ptr %[[PTR]], i32 %[[OFF]]
23+
// CHECK-NEXT: %[[RES2:.*]] = getelementptr nuw i8, ptr %[[PTR]], i32 %[[OFF]]
24+
// CHECK-NEXT: %[[RES3:.*]] = getelementptr inbounds i8, ptr %[[PTR]], i32 %[[OFF]]
25+
// CHECK-NEXT: ret ptr %[[RES]]
26+
// CHECK-NEXT: }
27+
llvm.func @ptr_add(%ptr: !ptr.ptr<#llvm.address_space<0>>, %off: i32) -> !ptr.ptr<#llvm.address_space<0>> {
28+
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
29+
%res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
30+
%res1 = ptr.ptr_add nusw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
31+
%res2 = ptr.ptr_add nuw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
32+
%res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
33+
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
34+
}
35+
36+
// CHECK-LABEL: define { i32, i32, i32, i32 } @type_offset
37+
// CHECK-NEXT: ret { i32, i32, i32, i32 } { i32 8, i32 1, i32 2, i32 4 }
38+
llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<(i32, i32, i32, i32)> {
39+
%0 = ptr.type_offset f64 : i32
40+
%1 = ptr.type_offset i8 : i32
41+
%2 = ptr.type_offset i16 : i32
42+
%3 = ptr.type_offset i32 : i32
43+
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
44+
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
45+
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
46+
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
47+
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
48+
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
49+
}
50+
51+
// CHECK-LABEL: define void @load_ops
52+
// CHECK-SAME: (ptr %[[PTR:.*]]) {
53+
// CHECK-NEXT: %[[V0:.*]] = load float, ptr %[[PTR]], align 4
54+
// CHECK-NEXT: %[[V1:.*]] = load volatile float, ptr %[[PTR]], align 4
55+
// CHECK-NEXT: %[[V2:.*]] = load float, ptr %[[PTR]], align 4, !nontemporal !{{.*}}
56+
// CHECK-NEXT: %[[V3:.*]] = load float, ptr %[[PTR]], align 4, !invariant.load !{{.*}}
57+
// CHECK-NEXT: %[[V4:.*]] = load float, ptr %[[PTR]], align 4, !invariant.group !{{.*}}
58+
// CHECK-NEXT: %[[V5:.*]] = load atomic i64, ptr %[[PTR]] monotonic, align 8
59+
// CHECK-NEXT: %[[V6:.*]] = load atomic volatile i32, ptr %[[PTR]] syncscope("workgroup") acquire, align 4, !nontemporal !{{.*}}
60+
// CHECK-NEXT: ret void
61+
// CHECK-NEXT: }
62+
llvm.func @load_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>) {
63+
%0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
64+
%1 = ptr.load volatile %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
65+
%2 = ptr.load %arg0 nontemporal : !ptr.ptr<#llvm.address_space<0>> -> f32
66+
%3 = ptr.load %arg0 invariant : !ptr.ptr<#llvm.address_space<0>> -> f32
67+
%4 = ptr.load %arg0 invariant_group : !ptr.ptr<#llvm.address_space<0>> -> f32
68+
%5 = ptr.load %arg0 atomic monotonic alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> i64
69+
%6 = ptr.load volatile %arg0 atomic syncscope("workgroup") acquire nontemporal alignment = 4 : !ptr.ptr<#llvm.address_space<0>> -> i32
70+
llvm.return
71+
}
72+
73+
// CHECK-LABEL: define void @store_ops
74+
// CHECK-SAME: (ptr %[[PTR:.*]], float %[[ARG1:.*]], i64 %[[ARG2:.*]], i32 %[[ARG3:.*]]) {
75+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4
76+
// CHECK-NEXT: store volatile float %[[ARG1]], ptr %[[PTR]], align 4
77+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !nontemporal !{{.*}}
78+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !invariant.group !{{.*}}
79+
// CHECK-NEXT: store atomic i64 %[[ARG2]], ptr %[[PTR]] monotonic, align 8
80+
// CHECK-NEXT: store atomic volatile i32 %[[ARG3]], ptr %[[PTR]] syncscope("workgroup") release, align 4, !nontemporal !{{.*}}
81+
// CHECK-NEXT: ret void
82+
// CHECK-NEXT: }
83+
llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2: i64, %arg3: i32) {
84+
ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
85+
ptr.store volatile %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
86+
ptr.store %arg1, %arg0 nontemporal : f32, !ptr.ptr<#llvm.address_space<0>>
87+
ptr.store %arg1, %arg0 invariant_group : f32, !ptr.ptr<#llvm.address_space<0>>
88+
ptr.store %arg2, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#llvm.address_space<0>>
89+
ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
90+
llvm.return
91+
}

0 commit comments

Comments
 (0)