Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [
let cppNamespace = "::mlir::ptr";
}

//===----------------------------------------------------------------------===//
// Ptr diff flags enum properties.
//===----------------------------------------------------------------------===//

def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [
I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2>
]> {
let cppNamespace = "::mlir::ptr";
}

#endif // PTR_ENUMS
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}

//===----------------------------------------------------------------------===//
// PtrDiffOp
//===----------------------------------------------------------------------===//

def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape
]> {
let summary = "Pointer difference operation";
let description = [{
The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them.

The operation supports both scalar and shaped types with value semantics:
- When both operands are scalar: produces a single difference value
- When both are shaped: performs element-wise subtraction,
shapes must be the same

The operation also supports the following flags:
- `none`: No flags are set.
- `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow
(that is: the result would be negative), the result is a poison value.
- `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
result is a poison value.

NOTE: The pointer difference is calculated using an integer type specified
by the data layout. The final result will be sign-extended or truncated to
fit the result type as necessary.

Example:

```mlir
// Scalar pointers
%diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64

// Shaped pointers
%diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 :
vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
```
}];
let arguments = (ins
Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs,
DefaultValuedProp<EnumProp<Ptr_PtrDiffFlags>, "PtrDiffFlags::none">:$flags
);
let results = (outs Ptr_IntLikeType:$result);
let assemblyFormat = [{
($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
}];
let extraClassDeclaration = [{
/// Returns the operand's ptr type.
ptr::PtrType getPtrType();
/// Returns the result's underlying int type.
Type getIntType();
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringExtras.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an accidental include.

#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes(
return success();
}

//===----------------------------------------------------------------------===//
// PtrDiffOp
//===----------------------------------------------------------------------===//

LogicalResult PtrDiffOp::verify() {
// If the operands are not shaped early exit.
if (!isa<ShapedType>(getLhs().getType()))
return success();

// Just check the container type matches, `SameOperandsAndResultShape` handles
// the actual shape.
if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
return emitError() << "expected the result to have the same container "
"type as the operands when operands are shaped";
}

return success();
}

ptr::PtrType PtrDiffOp::getPtrType() {
Type lhsType = getLhs().getType();
if (auto shapedType = dyn_cast<ShapedType>(lhsType))
return cast<ptr::PtrType>(shapedType.getElementType());
return cast<ptr::PtrType>(lhsType);
}

Type PtrDiffOp::getIntType() {
Type resultType = getResult().getType();
if (auto shapedType = dyn_cast<ShapedType>(resultType))
return shapedType.getElementType();
return resultType;
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get a folder for the subtraction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, in general I need to improve the folders and canonicalizers of the entire dialect, so how about a new PR for that?

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 39 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,42 @@ translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
return success();
}

/// Translate ptr.ptr_diff operation operation to LLVM IR.
static LogicalResult
translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());

if (!lhs || !rhs)
return ptrDiffOp.emitError("Failed to lookup operands");

// Translate result type to LLVM type
llvm::Type *resultType =
moduleTranslation.convertType(ptrDiffOp.getResult().getType());
if (!resultType)
return ptrDiffOp.emitError("Failed to translate result type");

PtrDiffFlags flags = ptrDiffOp.getFlags();

// Convert both pointers to integers using ptrtoaddr, and compute the
// difference: lhs - rhs
llvm::Value *llLhs = builder.CreatePtrToAddr(lhs);
llvm::Value *llRhs = builder.CreatePtrToAddr(rhs);
llvm::Value *result = builder.CreateSub(
llLhs, llRhs, /*Name=*/"",
/*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
/*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);

// Convert the difference to the expected result type by truncating or
// extending.
if (result->getType() != resultType)
result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);

moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
return success();
}

/// Implementation of the dialect interface that translates operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
Expand All @@ -371,6 +407,9 @@ class PtrDialectLLVMIRTranslationInterface
.Case([&](PtrAddOp ptrAddOp) {
return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
})
.Case([&](PtrDiffOp ptrDiffOp) {
return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation);
})
.Case([&](LoadOp loadOp) {
return translateLoadOp(loadOp, builder, moduleTranslation);
})
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Ptr/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>,
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}

// -----

func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> {
// expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}}
%res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
return %res : vector<8xi64>
}
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Ptr/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
}

/// Test ptr_diff operations with scalar pointers
func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) {
%diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64
%diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index
%diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32
return %diff_i64, %diff_index, %diff_i32 : i64, index, i32
}

/// Test ptr_diff operations with vector pointers
func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) {
%diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
%diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex>
return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex>
}

/// Test ptr_diff operations with tensor pointers
func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) {
%diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64>
%diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32>
}

/// Test ptr_diff operations with 2D tensor pointers
func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> {
%diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
return %diff : tensor<4x8xi64>
}
96 changes: 96 additions & 0 deletions mlir/test/Target/LLVMIR/ptr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}

// CHECK-LABEL: define i64 @ptr_diff_scalar
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
%diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
llvm.return %diff : i64
}

// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
// CHECK-NEXT: ret i32 %[[TRUNC]]
// CHECK-NEXT: }
llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 {
%diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32
llvm.return %diff : i32
}

// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret <4 x i64> %[[DIFF]]
// CHECK-NEXT: }
llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> {
%diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64>
llvm.return %diffs : vector<4xi64>
}

// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
// CHECK-NEXT: ret <8 x i32> %[[TRUNC]]
// CHECK-NEXT: }
llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
%diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
llvm.return %diffs : vector<8xi32>
}

// CHECK-LABEL: define i64 @ptr_diff_with_constants() {
// CHECK-NEXT: ret i64 4096
// CHECK-NEXT: }
llvm.func @ptr_diff_with_constants() -> i64 {
%ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>>
%ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>>
%diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
llvm.return %diff : i64
}

// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
%diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
llvm.return %diff : i64
}

// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
%diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
llvm.return %diff : i64
}

// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
%diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
llvm.return %diff : i64
}