16
16
#include " mlir/IR/BuiltinAttributes.h"
17
17
#include " mlir/IR/Operation.h"
18
18
#include " mlir/Target/LLVMIR/ModuleTranslation.h"
19
+ #include " llvm/ADT/TypeSwitch.h"
20
+ #include " llvm/IR/IRBuilder.h"
21
+ #include " llvm/IR/Instructions.h"
22
+ #include " llvm/IR/Type.h"
23
+ #include " llvm/IR/Value.h"
19
24
20
25
using namespace mlir ;
21
26
using namespace mlir ::ptr;
22
27
23
28
namespace {
29
+
30
+ // / Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31
+ static llvm::AtomicOrdering
32
+ convertAtomicOrdering (ptr::AtomicOrdering ordering) {
33
+ switch (ordering) {
34
+ case ptr::AtomicOrdering::not_atomic:
35
+ return llvm::AtomicOrdering::NotAtomic;
36
+ case ptr::AtomicOrdering::unordered:
37
+ return llvm::AtomicOrdering::Unordered;
38
+ case ptr::AtomicOrdering::monotonic:
39
+ return llvm::AtomicOrdering::Monotonic;
40
+ case ptr::AtomicOrdering::acquire:
41
+ return llvm::AtomicOrdering::Acquire;
42
+ case ptr::AtomicOrdering::release:
43
+ return llvm::AtomicOrdering::Release;
44
+ case ptr::AtomicOrdering::acq_rel:
45
+ return llvm::AtomicOrdering::AcquireRelease;
46
+ case ptr::AtomicOrdering::seq_cst:
47
+ return llvm::AtomicOrdering::SequentiallyConsistent;
48
+ }
49
+ llvm_unreachable (" Unknown atomic ordering" );
50
+ }
51
+
52
+ // / Convert ptr.ptr_add operation
53
+ static LogicalResult
54
+ convertPtrAddOp (PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55
+ LLVM::ModuleTranslation &moduleTranslation) {
56
+ llvm::Value *basePtr = moduleTranslation.lookupValue (ptrAddOp.getBase ());
57
+ llvm::Value *offset = moduleTranslation.lookupValue (ptrAddOp.getOffset ());
58
+
59
+ if (!basePtr || !offset)
60
+ return ptrAddOp.emitError (" Failed to lookup operands" );
61
+
62
+ // Create the GEP flags
63
+ llvm::GEPNoWrapFlags gepFlags;
64
+ switch (ptrAddOp.getFlags ()) {
65
+ case ptr::PtrAddFlags::none:
66
+ break ;
67
+ case ptr::PtrAddFlags::nusw:
68
+ gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap ();
69
+ break ;
70
+ case ptr::PtrAddFlags::nuw:
71
+ gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap ();
72
+ break ;
73
+ case ptr::PtrAddFlags::inbounds:
74
+ gepFlags = llvm::GEPNoWrapFlags::inBounds ();
75
+ break ;
76
+ }
77
+
78
+ // Create GEP instruction for pointer arithmetic
79
+ llvm::Value *gep =
80
+ builder.CreateGEP (builder.getInt8Ty (), basePtr, {offset}, " " , gepFlags);
81
+
82
+ moduleTranslation.mapValue (ptrAddOp.getResult (), gep);
83
+ return success ();
84
+ }
85
+
86
+ // / Convert ptr.load operation
87
+ static LogicalResult convertLoadOp (LoadOp loadOp, llvm::IRBuilderBase &builder,
88
+ LLVM::ModuleTranslation &moduleTranslation) {
89
+ llvm::Value *ptr = moduleTranslation.lookupValue (loadOp.getPtr ());
90
+ if (!ptr)
91
+ return loadOp.emitError (" Failed to lookup pointer operand" );
92
+
93
+ // Convert result type to LLVM type
94
+ llvm::Type *resultType =
95
+ moduleTranslation.convertType (loadOp.getValue ().getType ());
96
+ if (!resultType)
97
+ return loadOp.emitError (" Failed to convert result type" );
98
+
99
+ // Create the load instruction.
100
+ llvm::MaybeAlign alignment (loadOp.getAlignment ().value_or (0 ));
101
+ llvm::LoadInst *loadInst = builder.CreateAlignedLoad (
102
+ resultType, ptr, alignment, loadOp.getVolatile_ ());
103
+
104
+ // Set op flags and metadata.
105
+ loadInst->setAtomic (convertAtomicOrdering (loadOp.getOrdering ()));
106
+ // Set sync scope if specified
107
+ if (loadOp.getSyncscope ().has_value ()) {
108
+ llvm::LLVMContext &ctx = builder.getContext ();
109
+ llvm::SyncScope::ID syncScope =
110
+ ctx.getOrInsertSyncScopeID (loadOp.getSyncscope ().value ());
111
+ loadInst->setSyncScopeID (syncScope);
112
+ }
113
+
114
+ // Set metadata for nontemporal, invariant, and invariant_group
115
+ if (loadOp.getNontemporal ()) {
116
+ llvm::MDNode *nontemporalMD =
117
+ llvm::MDNode::get (builder.getContext (),
118
+ llvm::ConstantAsMetadata::get (builder.getInt32 (1 )));
119
+ loadInst->setMetadata (llvm::LLVMContext::MD_nontemporal, nontemporalMD);
120
+ }
121
+
122
+ if (loadOp.getInvariant ()) {
123
+ llvm::MDNode *invariantMD = llvm::MDNode::get (builder.getContext (), {});
124
+ loadInst->setMetadata (llvm::LLVMContext::MD_invariant_load, invariantMD);
125
+ }
126
+
127
+ if (loadOp.getInvariantGroup ()) {
128
+ llvm::MDNode *invariantGroupMD =
129
+ llvm::MDNode::get (builder.getContext (), {});
130
+ loadInst->setMetadata (llvm::LLVMContext::MD_invariant_group,
131
+ invariantGroupMD);
132
+ }
133
+
134
+ moduleTranslation.mapValue (loadOp.getResult (), loadInst);
135
+ return success ();
136
+ }
137
+
138
+ // / Convert ptr.store operation
139
+ static LogicalResult
140
+ convertStoreOp (StoreOp storeOp, llvm::IRBuilderBase &builder,
141
+ LLVM::ModuleTranslation &moduleTranslation) {
142
+ llvm::Value *value = moduleTranslation.lookupValue (storeOp.getValue ());
143
+ llvm::Value *ptr = moduleTranslation.lookupValue (storeOp.getPtr ());
144
+
145
+ if (!value || !ptr)
146
+ return storeOp.emitError (" Failed to lookup operands" );
147
+
148
+ // Create the store instruction.
149
+ llvm::MaybeAlign alignment (storeOp.getAlignment ().value_or (0 ));
150
+ llvm::StoreInst *storeInst =
151
+ builder.CreateAlignedStore (value, ptr, alignment, storeOp.getVolatile_ ());
152
+
153
+ // Set op flags and metadata.
154
+ storeInst->setAtomic (convertAtomicOrdering (storeOp.getOrdering ()));
155
+ // Set sync scope if specified
156
+ if (storeOp.getSyncscope ().has_value ()) {
157
+ llvm::LLVMContext &ctx = builder.getContext ();
158
+ llvm::SyncScope::ID syncScope =
159
+ ctx.getOrInsertSyncScopeID (storeOp.getSyncscope ().value ());
160
+ storeInst->setSyncScopeID (syncScope);
161
+ }
162
+
163
+ // Set metadata for nontemporal and invariant_group
164
+ if (storeOp.getNontemporal ()) {
165
+ llvm::MDNode *nontemporalMD =
166
+ llvm::MDNode::get (builder.getContext (),
167
+ llvm::ConstantAsMetadata::get (builder.getInt32 (1 )));
168
+ storeInst->setMetadata (llvm::LLVMContext::MD_nontemporal, nontemporalMD);
169
+ }
170
+
171
+ if (storeOp.getInvariantGroup ()) {
172
+ llvm::MDNode *invariantGroupMD =
173
+ llvm::MDNode::get (builder.getContext (), {});
174
+ storeInst->setMetadata (llvm::LLVMContext::MD_invariant_group,
175
+ invariantGroupMD);
176
+ }
177
+
178
+ return success ();
179
+ }
180
+
181
+ // / Convert ptr.type_offset operation
182
+ static LogicalResult
183
+ convertTypeOffsetOp (TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
184
+ LLVM::ModuleTranslation &moduleTranslation) {
185
+ // Convert the element type to LLVM type
186
+ llvm::Type *elementType =
187
+ moduleTranslation.convertType (typeOffsetOp.getElementType ());
188
+ if (!elementType)
189
+ return typeOffsetOp.emitError (" Failed to convert the element type" );
190
+
191
+ // Convert result type
192
+ llvm::Type *resultType =
193
+ moduleTranslation.convertType (typeOffsetOp.getResult ().getType ());
194
+ if (!resultType)
195
+ return typeOffsetOp.emitError (" Failed to convert the result type" );
196
+
197
+ // Use GEP with null pointer to compute type size/offset.
198
+ llvm::Value *nullPtr = llvm::Constant::getNullValue (builder.getPtrTy (0 ));
199
+ llvm::Value *offsetPtr =
200
+ builder.CreateGEP (elementType, nullPtr, {builder.getInt32 (1 )});
201
+ llvm::Value *offset = builder.CreatePtrToInt (offsetPtr, resultType);
202
+
203
+ moduleTranslation.mapValue (typeOffsetOp.getResult (), offset);
204
+ return success ();
205
+ }
206
+
24
207
// / Implementation of the dialect interface that converts operations belonging
25
208
// / to the `ptr` dialect to LLVM IR.
26
209
class PtrDialectLLVMIRTranslationInterface
@@ -33,21 +216,33 @@ class PtrDialectLLVMIRTranslationInterface
33
216
LogicalResult
34
217
convertOperation (Operation *op, llvm::IRBuilderBase &builder,
35
218
LLVM::ModuleTranslation &moduleTranslation) const final {
36
- // Translation for ptr dialect operations to LLVM IR is currently
37
- // unimplemented.
38
- return op->emitError (" Translation for ptr dialect operations to LLVM IR is "
39
- " not implemented." );
219
+
220
+ return llvm::TypeSwitch<Operation *, LogicalResult>(op)
221
+ .Case ([&](PtrAddOp ptrAddOp) {
222
+ return convertPtrAddOp (ptrAddOp, builder, moduleTranslation);
223
+ })
224
+ .Case ([&](LoadOp loadOp) {
225
+ return convertLoadOp (loadOp, builder, moduleTranslation);
226
+ })
227
+ .Case ([&](StoreOp storeOp) {
228
+ return convertStoreOp (storeOp, builder, moduleTranslation);
229
+ })
230
+ .Case ([&](TypeOffsetOp typeOffsetOp) {
231
+ return convertTypeOffsetOp (typeOffsetOp, builder, moduleTranslation);
232
+ })
233
+ .Default ([&](Operation *op) {
234
+ return op->emitError (" Translation for operation '" )
235
+ << op->getName () << " ' is not implemented." ;
236
+ });
40
237
}
41
238
42
239
// / Attaches module-level metadata for functions marked as kernels.
43
240
LogicalResult
44
241
amendOperation (Operation *op, ArrayRef<llvm::Instruction *> instructions,
45
242
NamedAttribute attribute,
46
243
LLVM::ModuleTranslation &moduleTranslation) const final {
47
- // Translation for ptr dialect operations to LLVM IR is currently
48
- // unimplemented.
49
- return op->emitError (" Translation for ptr dialect operations to LLVM IR is "
50
- " not implemented." );
244
+ // No special amendments needed for ptr dialect operations
245
+ return success ();
51
246
}
52
247
};
53
248
} // namespace
0 commit comments