Skip to content

Commit 46fa8d4

Browse files
committed
add translations for ptr ops
1 parent 11f4be0 commit 46fa8d4

File tree

3 files changed

+278
-10
lines changed

3 files changed

+278
-10
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def LLVM_NonLoadableTargetExtType : Type<
9595
// type that has size (not void, function, opaque struct type or target
9696
// extension type which does not support memory operations).
9797
def LLVM_LoadableType : Type<
98-
Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>,
99-
Neg<LLVM_NonLoadableTargetExtType.predicate>]>,
98+
Or<[CPred<"mlir::LLVM::isLoadableType($_self)">,
10099
LLVM_PointerElementTypeInterface.predicate]>,
101100
"LLVM type with size">;
102101

mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp

Lines changed: 202 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,193 @@
1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
19+
#include "llvm/ADT/TypeSwitch.h"
20+
#include "llvm/IR/IRBuilder.h"
21+
#include "llvm/IR/Instructions.h"
22+
#include "llvm/IR/Type.h"
23+
#include "llvm/IR/Value.h"
1924

2025
using namespace mlir;
2126
using namespace mlir::ptr;
2227

2328
namespace {
29+
30+
/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31+
static llvm::AtomicOrdering
32+
convertAtomicOrdering(ptr::AtomicOrdering ordering) {
33+
switch (ordering) {
34+
case ptr::AtomicOrdering::not_atomic:
35+
return llvm::AtomicOrdering::NotAtomic;
36+
case ptr::AtomicOrdering::unordered:
37+
return llvm::AtomicOrdering::Unordered;
38+
case ptr::AtomicOrdering::monotonic:
39+
return llvm::AtomicOrdering::Monotonic;
40+
case ptr::AtomicOrdering::acquire:
41+
return llvm::AtomicOrdering::Acquire;
42+
case ptr::AtomicOrdering::release:
43+
return llvm::AtomicOrdering::Release;
44+
case ptr::AtomicOrdering::acq_rel:
45+
return llvm::AtomicOrdering::AcquireRelease;
46+
case ptr::AtomicOrdering::seq_cst:
47+
return llvm::AtomicOrdering::SequentiallyConsistent;
48+
}
49+
llvm_unreachable("Unknown atomic ordering");
50+
}
51+
52+
/// Convert ptr.ptr_add operation
53+
static LogicalResult
54+
convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55+
LLVM::ModuleTranslation &moduleTranslation) {
56+
llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57+
llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
58+
59+
if (!basePtr || !offset)
60+
return ptrAddOp.emitError("Failed to lookup operands");
61+
62+
// Create GEP instruction for pointer arithmetic
63+
llvm::GetElementPtrInst *gep = llvm::GetElementPtrInst::Create(
64+
builder.getInt8Ty(), basePtr, {offset}, "", builder.GetInsertBlock());
65+
66+
// Set the appropriate flags
67+
switch (ptrAddOp.getFlags()) {
68+
case ptr::PtrAddFlags::none:
69+
break;
70+
case ptr::PtrAddFlags::nusw:
71+
gep->setNoWrapFlags(llvm::GEPNoWrapFlags::noUnsignedSignedWrap());
72+
break;
73+
case ptr::PtrAddFlags::nuw:
74+
gep->setNoWrapFlags(llvm::GEPNoWrapFlags::noUnsignedWrap());
75+
break;
76+
case ptr::PtrAddFlags::inbounds:
77+
gep->setNoWrapFlags(llvm::GEPNoWrapFlags::inBounds());
78+
break;
79+
}
80+
81+
moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
82+
return success();
83+
}
84+
85+
/// Convert ptr.load operation
86+
static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
87+
LLVM::ModuleTranslation &moduleTranslation) {
88+
llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
89+
if (!ptr)
90+
return loadOp.emitError("Failed to lookup pointer operand");
91+
92+
// Convert result type to LLVM type
93+
llvm::Type *resultType =
94+
moduleTranslation.convertType(loadOp.getValue().getType());
95+
if (!resultType)
96+
return loadOp.emitError("Failed to convert result type");
97+
98+
// Create the load instruction.
99+
llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
100+
resultType, ptr, llvm::MaybeAlign(loadOp.getAlignment().value_or(0)),
101+
loadOp.getVolatile_());
102+
103+
// Set op flags and metadata.
104+
loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
105+
// Set sync scope if specified
106+
if (loadOp.getSyncscope().has_value()) {
107+
llvm::LLVMContext &ctx = builder.getContext();
108+
llvm::SyncScope::ID syncScope =
109+
ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
110+
loadInst->setSyncScopeID(syncScope);
111+
}
112+
113+
// Set metadata for nontemporal, invariant, and invariant_group
114+
if (loadOp.getNontemporal()) {
115+
llvm::MDNode *nontemporalMD =
116+
llvm::MDNode::get(builder.getContext(),
117+
llvm::ConstantAsMetadata::get(builder.getInt32(1)));
118+
loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
119+
}
120+
121+
if (loadOp.getInvariant()) {
122+
llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
123+
loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
124+
}
125+
126+
if (loadOp.getInvariantGroup()) {
127+
llvm::MDNode *invariantGroupMD =
128+
llvm::MDNode::get(builder.getContext(), {});
129+
loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
130+
invariantGroupMD);
131+
}
132+
133+
moduleTranslation.mapValue(loadOp.getResult(), loadInst);
134+
return success();
135+
}
136+
137+
/// Convert ptr.store operation
138+
static LogicalResult
139+
convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
140+
LLVM::ModuleTranslation &moduleTranslation) {
141+
llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
142+
llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
143+
144+
if (!value || !ptr)
145+
return storeOp.emitError("Failed to lookup operands");
146+
147+
// Create the store instruction.
148+
llvm::StoreInst *storeInst = builder.CreateAlignedStore(
149+
value, ptr, llvm::MaybeAlign(storeOp.getAlignment().value_or(0)),
150+
storeOp.getVolatile_());
151+
152+
// Set op flags and metadata.
153+
storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
154+
// Set sync scope if specified
155+
if (storeOp.getSyncscope().has_value()) {
156+
llvm::LLVMContext &ctx = builder.getContext();
157+
llvm::SyncScope::ID syncScope =
158+
ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
159+
storeInst->setSyncScopeID(syncScope);
160+
}
161+
162+
// Set metadata for nontemporal and invariant_group
163+
if (storeOp.getNontemporal()) {
164+
llvm::MDNode *nontemporalMD =
165+
llvm::MDNode::get(builder.getContext(),
166+
llvm::ConstantAsMetadata::get(builder.getInt32(1)));
167+
storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
168+
}
169+
170+
if (storeOp.getInvariantGroup()) {
171+
llvm::MDNode *invariantGroupMD =
172+
llvm::MDNode::get(builder.getContext(), {});
173+
storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
174+
invariantGroupMD);
175+
}
176+
177+
return success();
178+
}
179+
180+
/// Convert ptr.type_offset operation
181+
static LogicalResult
182+
convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
183+
LLVM::ModuleTranslation &moduleTranslation) {
184+
// Convert the element type to LLVM type
185+
llvm::Type *elementType =
186+
moduleTranslation.convertType(typeOffsetOp.getElementType());
187+
if (!elementType)
188+
return typeOffsetOp.emitError("Failed to convert the element type");
189+
190+
// Convert result type
191+
llvm::Type *resultType =
192+
moduleTranslation.convertType(typeOffsetOp.getResult().getType());
193+
if (!resultType)
194+
return typeOffsetOp.emitError("Failed to convert the result type");
195+
196+
// Use GEP with null pointer to compute type size/offset.
197+
llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
198+
llvm::Value *offsetPtr =
199+
builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
200+
llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
201+
202+
moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
203+
return success();
204+
}
205+
24206
/// Implementation of the dialect interface that converts operations belonging
25207
/// to the `ptr` dialect to LLVM IR.
26208
class PtrDialectLLVMIRTranslationInterface
@@ -33,21 +215,33 @@ class PtrDialectLLVMIRTranslationInterface
33215
LogicalResult
34216
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
35217
LLVM::ModuleTranslation &moduleTranslation) const final {
36-
// Translation for ptr dialect operations to LLVM IR is currently
37-
// unimplemented.
38-
return op->emitError("Translation for ptr dialect operations to LLVM IR is "
39-
"not implemented.");
218+
219+
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
220+
.Case<PtrAddOp>([&](PtrAddOp ptrAddOp) {
221+
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
222+
})
223+
.Case<LoadOp>([&](LoadOp loadOp) {
224+
return convertLoadOp(loadOp, builder, moduleTranslation);
225+
})
226+
.Case<StoreOp>([&](StoreOp storeOp) {
227+
return convertStoreOp(storeOp, builder, moduleTranslation);
228+
})
229+
.Case<TypeOffsetOp>([&](TypeOffsetOp typeOffsetOp) {
230+
return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
231+
})
232+
.Default([&](Operation *op) {
233+
return op->emitError("Translation for operation '")
234+
<< op->getName() << "' is not implemented.";
235+
});
40236
}
41237

42238
/// Attaches module-level metadata for functions marked as kernels.
43239
LogicalResult
44240
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
45241
NamedAttribute attribute,
46242
LLVM::ModuleTranslation &moduleTranslation) const final {
47-
// Translation for ptr dialect operations to LLVM IR is currently
48-
// unimplemented.
49-
return op->emitError("Translation for ptr dialect operations to LLVM IR is "
50-
"not implemented.");
243+
// No special amendments needed for ptr dialect operations
244+
return success();
51245
}
52246
};
53247
} // namespace

mlir/test/Target/LLVMIR/ptr.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,78 @@ llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) {
1414
llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
1515
llvm.return
1616
}
17+
18+
// CHECK-LABEL: define ptr @ptr_add
19+
// CHECK-SAME: (ptr %[[PTR:.*]], i32 %[[OFF:.*]]) {
20+
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
21+
// CHECK-NEXT: %[[RES0:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
22+
// CHECK-NEXT: %[[RES1:.*]] = getelementptr nusw i8, ptr %[[PTR]], i32 %[[OFF]]
23+
// CHECK-NEXT: %[[RES2:.*]] = getelementptr nuw i8, ptr %[[PTR]], i32 %[[OFF]]
24+
// CHECK-NEXT: %[[RES3:.*]] = getelementptr inbounds i8, ptr %[[PTR]], i32 %[[OFF]]
25+
// CHECK-NEXT: ret ptr %[[RES]]
26+
// CHECK-NEXT: }
27+
llvm.func @ptr_add(%ptr: !ptr.ptr<#llvm.address_space<0>>, %off: i32) -> !ptr.ptr<#llvm.address_space<0>> {
28+
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
29+
%res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
30+
%res1 = ptr.ptr_add nusw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
31+
%res2 = ptr.ptr_add nuw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
32+
%res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
33+
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
34+
}
35+
36+
// CHECK-LABEL: define { i32, i32, i32, i32 } @type_offset
37+
// CHECK-NEXT: ret { i32, i32, i32, i32 } { i32 8, i32 1, i32 2, i32 4 }
38+
llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<(i32, i32, i32, i32)> {
39+
%0 = ptr.type_offset f64 : i32
40+
%1 = ptr.type_offset i8 : i32
41+
%2 = ptr.type_offset i16 : i32
42+
%3 = ptr.type_offset i32 : i32
43+
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
44+
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
45+
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
46+
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
47+
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
48+
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
49+
}
50+
51+
// CHECK-LABEL: define void @load_ops
52+
// CHECK-SAME: (ptr %[[PTR:.*]]) {
53+
// CHECK-NEXT: %[[V0:.*]] = load float, ptr %[[PTR]], align 4
54+
// CHECK-NEXT: %[[V1:.*]] = load volatile float, ptr %[[PTR]], align 4
55+
// CHECK-NEXT: %[[V2:.*]] = load float, ptr %[[PTR]], align 4, !nontemporal !{{.*}}
56+
// CHECK-NEXT: %[[V3:.*]] = load float, ptr %[[PTR]], align 4, !invariant.load !{{.*}}
57+
// CHECK-NEXT: %[[V4:.*]] = load float, ptr %[[PTR]], align 4, !invariant.group !{{.*}}
58+
// CHECK-NEXT: %[[V5:.*]] = load atomic i64, ptr %[[PTR]] monotonic, align 8
59+
// CHECK-NEXT: %[[V6:.*]] = load atomic volatile i32, ptr %[[PTR]] syncscope("workgroup") acquire, align 4, !nontemporal !{{.*}}
60+
// CHECK-NEXT: ret void
61+
// CHECK-NEXT: }
62+
llvm.func @load_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>) {
63+
%0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
64+
%1 = ptr.load volatile %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
65+
%2 = ptr.load %arg0 nontemporal : !ptr.ptr<#llvm.address_space<0>> -> f32
66+
%3 = ptr.load %arg0 invariant : !ptr.ptr<#llvm.address_space<0>> -> f32
67+
%4 = ptr.load %arg0 invariant_group : !ptr.ptr<#llvm.address_space<0>> -> f32
68+
%5 = ptr.load %arg0 atomic monotonic alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> i64
69+
%6 = ptr.load volatile %arg0 atomic syncscope("workgroup") acquire nontemporal alignment = 4 : !ptr.ptr<#llvm.address_space<0>> -> i32
70+
llvm.return
71+
}
72+
73+
// CHECK-LABEL: define void @store_ops
74+
// CHECK-SAME: (ptr %[[PTR:.*]], float %[[ARG1:.*]], i64 %[[ARG2:.*]], i32 %[[ARG3:.*]]) {
75+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4
76+
// CHECK-NEXT: store volatile float %[[ARG1]], ptr %[[PTR]], align 4
77+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !nontemporal !{{.*}}
78+
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !invariant.group !{{.*}}
79+
// CHECK-NEXT: store atomic i64 %[[ARG2]], ptr %[[PTR]] monotonic, align 8
80+
// CHECK-NEXT: store atomic volatile i32 %[[ARG3]], ptr %[[PTR]] syncscope("workgroup") release, align 4, !nontemporal !{{.*}}
81+
// CHECK-NEXT: ret void
82+
// CHECK-NEXT: }
83+
llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2: i64, %arg3: i32) {
84+
ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
85+
ptr.store volatile %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
86+
ptr.store %arg1, %arg0 nontemporal : f32, !ptr.ptr<#llvm.address_space<0>>
87+
ptr.store %arg1, %arg0 invariant_group : f32, !ptr.ptr<#llvm.address_space<0>>
88+
ptr.store %arg2, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#llvm.address_space<0>>
89+
ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
90+
llvm.return
91+
}

0 commit comments

Comments
 (0)