Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def LLVM_NonLoadableTargetExtType : Type<
// type that has size (not void, function, opaque struct type or target
// extension type which does not support memory operations).
// A loadable type is either a type for which `isLoadableType` holds (any
// sized primitive type that is not an opaque struct or a non-loadable
// target extension type), or a type implementing the pointer element type
// interface. NOTE(review): the merged-diff text retained both the removed
// `And<[...]>` predicate and the new `CPred` line, leaving `Or<[` malformed;
// this is the post-change definition.
def LLVM_LoadableType : Type<
    Or<[CPred<"mlir::LLVM::isLoadableType($_self)">,
        LLVM_PointerElementTypeInterface.predicate]>,
    "LLVM type with size">;

Expand Down
211 changes: 203 additions & 8 deletions mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,194 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

using namespace mlir;
using namespace mlir::ptr;

namespace {

/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
/// Maps a `ptr` dialect atomic ordering onto the corresponding LLVM IR
/// ordering. The switch is exhaustive over the dialect enum.
static llvm::AtomicOrdering
convertAtomicOrdering(ptr::AtomicOrdering ordering) {
  using LLVMOrdering = llvm::AtomicOrdering;
  switch (ordering) {
  case ptr::AtomicOrdering::not_atomic:
    return LLVMOrdering::NotAtomic;
  case ptr::AtomicOrdering::unordered:
    return LLVMOrdering::Unordered;
  case ptr::AtomicOrdering::monotonic:
    return LLVMOrdering::Monotonic;
  case ptr::AtomicOrdering::acquire:
    return LLVMOrdering::Acquire;
  case ptr::AtomicOrdering::release:
    return LLVMOrdering::Release;
  case ptr::AtomicOrdering::acq_rel:
    return LLVMOrdering::AcquireRelease;
  case ptr::AtomicOrdering::seq_cst:
    return LLVMOrdering::SequentiallyConsistent;
  }
  llvm_unreachable("Unknown atomic ordering");
}

/// Convert ptr.ptr_add operation
/// Translates a `ptr.ptr_add` into an LLVM byte-wise GEP (`getelementptr i8`)
/// carrying the op's no-wrap flags.
static LogicalResult
convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *base = moduleTranslation.lookupValue(ptrAddOp.getBase());
  llvm::Value *off = moduleTranslation.lookupValue(ptrAddOp.getOffset());
  if (!base || !off)
    return ptrAddOp.emitError("Failed to lookup operands");

  // Map the dialect-level flags to LLVM GEP no-wrap flags; `none` keeps the
  // default (empty) flag set.
  llvm::GEPNoWrapFlags wrapFlags;
  switch (ptrAddOp.getFlags()) {
  case ptr::PtrAddFlags::none:
    break;
  case ptr::PtrAddFlags::nusw:
    wrapFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
    break;
  case ptr::PtrAddFlags::nuw:
    wrapFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
    break;
  case ptr::PtrAddFlags::inbounds:
    wrapFlags = llvm::GEPNoWrapFlags::inBounds();
    break;
  }

  // Pointer arithmetic is expressed as an i8 GEP over the base pointer.
  llvm::Value *result =
      builder.CreateGEP(builder.getInt8Ty(), base, {off}, "", wrapFlags);
  moduleTranslation.mapValue(ptrAddOp.getResult(), result);
  return success();
}

/// Convert ptr.load operation
/// Translates a `ptr.load` into an LLVM `load`, carrying over alignment,
/// volatility, atomic ordering, sync scope, and the nontemporal /
/// invariant.load / invariant.group metadata flags.
static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *address = moduleTranslation.lookupValue(loadOp.getPtr());
  if (!address)
    return loadOp.emitError("Failed to lookup pointer operand");

  llvm::Type *loadedType =
      moduleTranslation.convertType(loadOp.getValue().getType());
  if (!loadedType)
    return loadOp.emitError("Failed to convert result type");

  llvm::LLVMContext &ctx = builder.getContext();

  // An absent alignment attribute becomes an empty MaybeAlign, letting LLVM
  // pick the ABI alignment.
  llvm::LoadInst *inst = builder.CreateAlignedLoad(
      loadedType, address, llvm::MaybeAlign(loadOp.getAlignment().value_or(0)),
      loadOp.getVolatile_());

  // Atomic ordering and, when present, the named synchronization scope.
  inst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
  if (auto scope = loadOp.getSyncscope())
    inst->setSyncScopeID(ctx.getOrInsertSyncScopeID(*scope));

  // Attach metadata flags requested on the op.
  if (loadOp.getNontemporal()) {
    // !nontemporal expects a single i32 1 operand.
    inst->setMetadata(
        llvm::LLVMContext::MD_nontemporal,
        llvm::MDNode::get(ctx,
                          llvm::ConstantAsMetadata::get(builder.getInt32(1))));
  }
  if (loadOp.getInvariant())
    inst->setMetadata(llvm::LLVMContext::MD_invariant_load,
                      llvm::MDNode::get(ctx, {}));
  if (loadOp.getInvariantGroup())
    inst->setMetadata(llvm::LLVMContext::MD_invariant_group,
                      llvm::MDNode::get(ctx, {}));

  moduleTranslation.mapValue(loadOp.getResult(), inst);
  return success();
}

/// Convert ptr.store operation
/// Translates a `ptr.store` into an LLVM `store`, carrying over alignment,
/// volatility, atomic ordering, sync scope, and the nontemporal /
/// invariant.group metadata flags.
static LogicalResult
convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *storedValue = moduleTranslation.lookupValue(storeOp.getValue());
  llvm::Value *address = moduleTranslation.lookupValue(storeOp.getPtr());
  if (!storedValue || !address)
    return storeOp.emitError("Failed to lookup operands");

  llvm::LLVMContext &ctx = builder.getContext();

  // An absent alignment attribute becomes an empty MaybeAlign, letting LLVM
  // pick the ABI alignment.
  llvm::StoreInst *inst = builder.CreateAlignedStore(
      storedValue, address,
      llvm::MaybeAlign(storeOp.getAlignment().value_or(0)),
      storeOp.getVolatile_());

  // Atomic ordering and, when present, the named synchronization scope.
  inst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
  if (auto scope = storeOp.getSyncscope())
    inst->setSyncScopeID(ctx.getOrInsertSyncScopeID(*scope));

  // Attach metadata flags requested on the op.
  if (storeOp.getNontemporal()) {
    // !nontemporal expects a single i32 1 operand.
    inst->setMetadata(
        llvm::LLVMContext::MD_nontemporal,
        llvm::MDNode::get(ctx,
                          llvm::ConstantAsMetadata::get(builder.getInt32(1))));
  }
  if (storeOp.getInvariantGroup())
    inst->setMetadata(llvm::LLVMContext::MD_invariant_group,
                      llvm::MDNode::get(ctx, {}));

  return success();
}

/// Convert ptr.type_offset operation
/// Translates a `ptr.type_offset` by computing `gep elemTy, null, 1` and
/// casting the resulting pointer to the requested integer type — the classic
/// "sizeof via null GEP" trick, which constant-folds for sized types.
static LogicalResult
convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Type *elemTy =
      moduleTranslation.convertType(typeOffsetOp.getElementType());
  if (!elemTy)
    return typeOffsetOp.emitError("Failed to convert the element type");

  llvm::Type *intTy =
      moduleTranslation.convertType(typeOffsetOp.getResult().getType());
  if (!intTy)
    return typeOffsetOp.emitError("Failed to convert the result type");

  // &((elemTy*)null)[1] converted to an integer equals the allocation size
  // of elemTy.
  llvm::Value *nullBase = llvm::Constant::getNullValue(builder.getPtrTy(0));
  llvm::Value *onepast =
      builder.CreateGEP(elemTy, nullBase, {builder.getInt32(1)});
  moduleTranslation.mapValue(typeOffsetOp.getResult(),
                             builder.CreatePtrToInt(oneast := nullptr ? nullptr : oneast, intTy));
  return success();
}

/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
Expand All @@ -33,21 +216,33 @@ class PtrDialectLLVMIRTranslationInterface
/// Translates a single `ptr` dialect operation to LLVM IR by dispatching on
/// the concrete op type; unknown ops produce an explicit error.
/// NOTE(review): the merged-diff text retained the deleted unconditional
/// `return op->emitError(...)` ahead of the TypeSwitch, which made every
/// translation unreachable — removed here.
LogicalResult
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) const final {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](PtrAddOp ptrAddOp) {
        return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
      })
      .Case([&](LoadOp loadOp) {
        return convertLoadOp(loadOp, builder, moduleTranslation);
      })
      .Case([&](StoreOp storeOp) {
        return convertStoreOp(storeOp, builder, moduleTranslation);
      })
      .Case([&](TypeOffsetOp typeOffsetOp) {
        return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
      })
      .Default([&](Operation *op) {
        return op->emitError("Translation for operation '")
               << op->getName() << "' is not implemented.";
      });
}

/// Hook for dialect-attribute-driven amendments; the `ptr` dialect defines
/// no such attributes, so this is a deliberate no-op.
/// NOTE(review): the merged-diff text retained the deleted unconditional
/// `return op->emitError(...)` ahead of `return success()`, making the
/// no-op body unreachable — removed here. The old comment ("kernels
/// metadata") did not match the body and is corrected.
LogicalResult
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
               NamedAttribute attribute,
               LLVM::ModuleTranslation &moduleTranslation) const final {
  // No special amendments needed for ptr dialect operations.
  return success();
}
};
} // namespace
Expand Down
75 changes: 75 additions & 0 deletions mlir/test/Target/LLVMIR/ptr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,78 @@ llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) {
llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
llvm.return
}

// CHECK-LABEL: define ptr @ptr_add
// CHECK-SAME: (ptr %[[PTR:.*]], i32 %[[OFF:.*]]) {
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
// CHECK-NEXT: %[[RES0:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
// CHECK-NEXT: %[[RES1:.*]] = getelementptr nusw i8, ptr %[[PTR]], i32 %[[OFF]]
// CHECK-NEXT: %[[RES2:.*]] = getelementptr nuw i8, ptr %[[PTR]], i32 %[[OFF]]
// CHECK-NEXT: %[[RES3:.*]] = getelementptr inbounds i8, ptr %[[PTR]], i32 %[[OFF]]
// CHECK-NEXT: ret ptr %[[RES]]
// CHECK-NEXT: }
// Exercises every ptr.ptr_add flag variant (default, none, nusw, nuw,
// inbounds); only the flagless result is returned.
llvm.func @ptr_add(%ptr: !ptr.ptr<#llvm.address_space<0>>, %off: i32) -> !ptr.ptr<#llvm.address_space<0>> {
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
%res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
%res1 = ptr.ptr_add nusw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
%res2 = ptr.ptr_add nuw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
%res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}

// CHECK-LABEL: define { i32, i32, i32, i32 } @type_offset
// CHECK-NEXT: ret { i32, i32, i32, i32 } { i32 8, i32 1, i32 2, i32 4 }
// Computes ptr.type_offset for f64/i8/i16/i32 and packs the four results
// into a struct; the CHECK above expects them to constant-fold to
// {8, 1, 2, 4}.
llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<(i32, i32, i32, i32)> {
%0 = ptr.type_offset f64 : i32
%1 = ptr.type_offset i8 : i32
%2 = ptr.type_offset i16 : i32
%3 = ptr.type_offset i32 : i32
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
}

// CHECK-LABEL: define void @load_ops
// CHECK-SAME: (ptr %[[PTR:.*]]) {
// CHECK-NEXT: %[[V0:.*]] = load float, ptr %[[PTR]], align 4
// CHECK-NEXT: %[[V1:.*]] = load volatile float, ptr %[[PTR]], align 4
// CHECK-NEXT: %[[V2:.*]] = load float, ptr %[[PTR]], align 4, !nontemporal !{{.*}}
// CHECK-NEXT: %[[V3:.*]] = load float, ptr %[[PTR]], align 4, !invariant.load !{{.*}}
// CHECK-NEXT: %[[V4:.*]] = load float, ptr %[[PTR]], align 4, !invariant.group !{{.*}}
// CHECK-NEXT: %[[V5:.*]] = load atomic i64, ptr %[[PTR]] monotonic, align 8
// CHECK-NEXT: %[[V6:.*]] = load atomic volatile i32, ptr %[[PTR]] syncscope("workgroup") acquire, align 4, !nontemporal !{{.*}}
// CHECK-NEXT: ret void
// CHECK-NEXT: }
// Exercises ptr.load variants: plain, volatile, nontemporal, invariant,
// invariant_group, atomic, and atomic+volatile+syncscope+nontemporal.
llvm.func @load_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>) {
%0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
%1 = ptr.load volatile %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
%2 = ptr.load %arg0 nontemporal : !ptr.ptr<#llvm.address_space<0>> -> f32
%3 = ptr.load %arg0 invariant : !ptr.ptr<#llvm.address_space<0>> -> f32
%4 = ptr.load %arg0 invariant_group : !ptr.ptr<#llvm.address_space<0>> -> f32
%5 = ptr.load %arg0 atomic monotonic alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> i64
%6 = ptr.load volatile %arg0 atomic syncscope("workgroup") acquire nontemporal alignment = 4 : !ptr.ptr<#llvm.address_space<0>> -> i32
llvm.return
}

// CHECK-LABEL: define void @store_ops
// CHECK-SAME: (ptr %[[PTR:.*]], float %[[ARG1:.*]], i64 %[[ARG2:.*]], i32 %[[ARG3:.*]]) {
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4
// CHECK-NEXT: store volatile float %[[ARG1]], ptr %[[PTR]], align 4
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !nontemporal !{{.*}}
// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !invariant.group !{{.*}}
// CHECK-NEXT: store atomic i64 %[[ARG2]], ptr %[[PTR]] monotonic, align 8
// CHECK-NEXT: store atomic volatile i32 %[[ARG3]], ptr %[[PTR]] syncscope("workgroup") release, align 4, !nontemporal !{{.*}}
// CHECK-NEXT: ret void
// CHECK-NEXT: }
// Exercises ptr.store variants: plain, volatile, nontemporal,
// invariant_group, atomic, and atomic+volatile+syncscope+nontemporal.
llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2: i64, %arg3: i32) {
ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
ptr.store volatile %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
ptr.store %arg1, %arg0 nontemporal : f32, !ptr.ptr<#llvm.address_space<0>>
ptr.store %arg1, %arg0 invariant_group : f32, !ptr.ptr<#llvm.address_space<0>>
ptr.store %arg2, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#llvm.address_space<0>>
ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
llvm.return
}
Loading