Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@ class Ptr_Attr<string name, string attrMnemonic,
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
// AddressAttr
//===----------------------------------------------------------------------===//

def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
DeclareAttrInterfaceMethods<TypedAttrInterface>
]> {
let summary = "Address attribute";
let description = [{
The `address` attribute represents a raw memory address.

Example:

```mlir
#ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
```
}];
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
APIntParameter<"">:$value);
let builders = [
AttrBuilderWithInferredContext<(ins "PtrType":$type,
"const llvm::APInt &":$value), [{
return $_get(type.getContext(), type, value);
}]>
];
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// GenericSpaceAttr
//===----------------------------------------------------------------------===//
Expand All @@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
- Load and store operations are always valid, regardless of the type.
- Atomic operations are always valid, regardless of the type.
- Cast operations to `generic_space` are always valid.

Example:

```mlir
#ptr.generic_space
#ptr.generic_space : !ptr.ptr<#ptr.generic_space>
```
}];
let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// NullAttr
//===----------------------------------------------------------------------===//

def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
DeclareAttrInterfaceMethods<TypedAttrInterface>
]> {
let summary = "Null pointer attribute";
let description = [{
The `null` attribute represents a null pointer.

Example:

```mlir
#ptr.null
```
}];
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
let builders = [
AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
return $_get(type.getContext(), type);
}]>
];
let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// SpecAttr
//===----------------------------------------------------------------------===//
Expand All @@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
- [Optional] index: bitwidth that should be used when performing index
computations for the type. Setting the field to `kOptionalSpecValue`, means
the field is optional.

Furthermore, the attribute will verify that all present values are divisible
by 8 (number of bits in a byte), and that `preferred` > `abi`.

Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/Dialect/Ptr/IR/PtrEnums.h"

namespace mlir {
namespace ptr {
class PtrType;
} // namespace ptr
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"

Expand Down
39 changes: 32 additions & 7 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
/*cppType=*/"::mlir::ShapedType">;

// A ptr-like type, either scalar or shaped type with value semantics.
def Ptr_PtrLikeType :
def Ptr_PtrLikeType :
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;

// An int-like type, either scalar or shaped type with value semantics.
Expand All @@ -57,6 +57,31 @@ def Ptr_Mask1DType :
def Ptr_Ptr1DType :
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

def Ptr_ConstantOp : Pointer_Op<"constant", [
ConstantLike, Pure, AllTypesMatch<["value", "result"]>
]> {
let summary = "Pointer constant operation";
let description = [{
The `constant` operation produces a pointer constant. The attribute must be
a typed attribute of pointer type.

Example:

```mlir
// Create a null pointer
%null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
```
}];
let arguments = (ins TypedAttrInterface:$value);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = "attr-dict $value";
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
Expand All @@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
```mlir
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>

// Cast the `%ptr` to a memref without utilizing metadata.
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
```
Expand Down Expand Up @@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
// Scalar base and offset
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32

// Shaped base with scalar offset
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32

// Scalar base with shaped offset
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>

// Both base and offset are shaped
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
```
Expand All @@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
let hasFolder = 1;
let extraClassDeclaration = [{
/// `ViewLikeOp::getViewSource` method.
/// `ViewLikeOp::getViewSource` method.
Value getViewSource() { return getBase(); }

/// Returns the ptr type of the operation.
Expand Down Expand Up @@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
// Scatter values to multiple memory locations
ptr.scatter %value, %ptrs, %mask :
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>

// Scatter with alignment
ptr.scatter %value, %ptrs, %mask alignment = 8 :
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
Expand Down
7 changes: 4 additions & 3 deletions mlir/include/mlir/IR/DialectImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ struct FieldParser<

/// Parse any integer.
template <typename IntT>
struct FieldParser<IntT,
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
std::is_same_v<IntT, llvm::APInt>),
IntT>> {
static FailureOr<IntT> parse(AsmParser &parser) {
IntT value = 0;
IntT value{};
Copy link

Copilot AI Sep 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default initialization IntT value{} will zero-initialize llvm::APInt, but APInt requires explicit bit width specification in its constructor. This will cause compilation errors when IntT is llvm::APInt.

Copilot uses AI. Check for mistakes.

if (parser.parseInteger(value))
return failure();
return value;
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
return success();
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
Expand Down
52 changes: 52 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,55 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
return success();
}

/// Convert ptr.constant operation
static LogicalResult
convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
// Convert result type to LLVM type
llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
moduleTranslation.convertType(constantOp.getResult().getType()));
if (!resultType)
return constantOp.emitError("Expected a valid pointer type");

llvm::Value *result = nullptr;

TypedAttr value = constantOp.getValue();
if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
// Create a null pointer constant
result = llvm::ConstantPointerNull::get(resultType);
} else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
// Create an integer constant and convert it to pointer
llvm::APInt addressValue = addressAttr.getValue();

// Determine the integer type width based on the target's pointer size
llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
unsigned pointerSizeInBits =
dataLayout.getPointerSizeInBits(resultType->getAddressSpace());

// Extend or truncate the address value to match pointer size if needed
if (addressValue.getBitWidth() != pointerSizeInBits) {
if (addressValue.getBitWidth() > pointerSizeInBits) {
constantOp.emitWarning()
<< "Truncating address value to fit pointer size";
}
addressValue = addressValue.getBitWidth() < pointerSizeInBits
? addressValue.zext(pointerSizeInBits)
: addressValue.trunc(pointerSizeInBits);
}

// Create integer constant and convert to pointer
llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
result = builder.CreateIntToPtr(intValue, resultType);
} else {
return constantOp.emitError("Unsupported constant attribute type");
}

moduleTranslation.mapValue(constantOp.getResult(), result);
return success();
}

/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
Expand All @@ -314,6 +363,9 @@ class PtrDialectLLVMIRTranslationInterface
LLVM::ModuleTranslation &moduleTranslation) const final {

return llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case([&](ConstantOp constantOp) {
return convertConstantOp(constantOp, builder, moduleTranslation);
})
.Case([&](PtrAddOp ptrAddOp) {
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
})
Expand Down
24 changes: 23 additions & 1 deletion mlir/test/Dialect/Ptr/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.ge
}

/// Test operations with LLVM address space
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
%mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
// Gather from shared memory (address space 3)
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
Expand Down Expand Up @@ -189,3 +189,25 @@ func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.gener
%res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}

/// Test constant operations with null pointer
func.func @constant_null_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>) {
%null_generic = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
%null_as1 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>>
return %null_generic, %null_as1 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>
}

/// Test constant operations with address values
func.func @constant_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) {
%addr_0 = ptr.constant #ptr.address<0> : !ptr.ptr<#ptr.generic_space>
%addr_1000 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
%addr_deadbeef = ptr.constant #ptr.address<0xDEADBEEF> : !ptr.ptr<#llvm.address_space<3>>
return %addr_0, %addr_1000, %addr_deadbeef : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>
}

/// Test constant operations with large address values
func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>) {
%addr_max32 = ptr.constant #ptr.address<0xFFFFFFFF> : !ptr.ptr<#ptr.generic_space>
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
}
36 changes: 31 additions & 5 deletions mlir/test/Target/LLVMIR/ptr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<
%2 = ptr.type_offset i16 : i32
%3 = ptr.type_offset i32 : i32
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
}

Expand Down Expand Up @@ -194,7 +194,7 @@ llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm
// CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
// CHECK-NEXT: ret void
// CHECK-NEXT: }
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
%mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
// Test with shared memory address space (3) and f64 elements
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
Expand Down Expand Up @@ -255,3 +255,29 @@ llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) {
llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space<global>>, !llvm.ptr
llvm.return
}

// CHECK-LABEL: define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
// CHECK-NEXT: ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
llvm.func @constant_address_op() ->
!llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
!ptr.ptr<#llvm.address_space<1>>,
!ptr.ptr<#llvm.address_space<2>>)> {
%0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
%1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
%2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
%3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
}

// Test gep folders.
// CHECK-LABEL: define ptr @ptr_add_cst() {
// CHECK-NEXT: ret ptr inttoptr (i64 42 to ptr)
llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
%off = llvm.mlir.constant(42 : i32) : i32
%ptr = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}
Loading