-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][ptr] Add ptr.ptr_diff
operation
#157354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Fabian Mora (fabianmcg) ChangesThi patch introduces the
This patch also adds translation to LLVM IR hooks for the Example: llvm.func @<!-- -->ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
%diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
llvm.return %diffs : vector<8xi32>
} Translation to LLVM IR: define <8 x i32> @<!-- -->ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
%3 = ptrtoint <8 x ptr> %0 to <8 x i64>
%4 = ptrtoint <8 x ptr> %1 to <8 x i64>
%5 = sub <8 x i64> %3, %4
%6 = trunc <8 x i64> %5 to <8 x i32>
ret <8 x i32> %6
} Full diff: https://github.com/llvm/llvm-project/pull/157354.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
index c169f48e573d0..c97bd04d32896 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [
let cppNamespace = "::mlir::ptr";
}
+//===----------------------------------------------------------------------===//
+// Ptr diff flags enum properties.
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [
+ I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2>
+ ]> {
+ let cppNamespace = "::mlir::ptr";
+}
+
#endif // PTR_ENUMS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 468a3004d5c62..85902fdf7159e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
+ Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape
+ ]> {
+ let summary = "Pointer difference operation";
+ let description = [{
+ The `ptr_diff` operation computes the difference between two pointers,
+ returning an integer or index value representing the number of bytes
+ between them. This difference is always computed using signed arithmetic.
+
+ The operation supports both scalar and shaped types with value semantics:
+ - When both operands are scalar: produces a single difference value
+ - When both are shaped: performs element-wise subtraction,
+ shapes must be the same
+
+ The operation also supports the following flags:
+ - `none`: No flags are set.
+ - `nuw`: No Unsigned Wrap, if the addition causes an unsigned overflow,
+ the result is a poison value.
+ - `nsw`: No Signed Wrap, if the addition causes a signed overflow, the
+ result is a poison value.
+
+ NOTE: The pointer difference is calculated using an integer type specified
+ by the data layout. The final result will be sign-extended or truncated to
+ fit the result type as necessary.
+
+ Example:
+
+ ```mlir
+ // Scalar pointers
+ %diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64
+
+ // Shaped pointers
+ %diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+ ```
+ }];
+ let arguments = (ins
+ Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs,
+ DefaultValuedProp<EnumProp<Ptr_PtrDiffFlags>, "PtrDiffFlags::none">:$flags
+ );
+ let results = (outs Ptr_IntLikeType:$result);
+ let assemblyFormat = [{
+ ($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
+ }];
+ let extraClassDeclaration = [{
+ /// Returns the operand's ptr type.
+ ptr::PtrType getPtrType();
+ /// Returns the result's underlying int type.
+ Type getIntType();
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index f0209af8a1ca3..51f25f755a8a6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes(
return success();
}
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PtrDiffOp::verify() {
+ // If the operands are not shaped early exit.
+ if (!isa<ShapedType>(getLhs().getType()))
+ return success();
+
+ // Just check the container type matches, `SameOperandsAndResultShape` handles
+ // the actual shape.
+ if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
+ return emitError() << "expected the result to have the same container "
+ "type as the operands when operands are shaped";
+ }
+
+ return success();
+}
+
+ptr::PtrType PtrDiffOp::getPtrType() {
+ Type lhsType = getLhs().getType();
+ if (auto shapedType = dyn_cast<ShapedType>(lhsType))
+ return cast<ptr::PtrType>(shapedType.getElementType());
+ return cast<ptr::PtrType>(lhsType);
+}
+
+Type PtrDiffOp::getIntType() {
+ Type resultType = getResult().getType();
+ if (auto shapedType = dyn_cast<ShapedType>(resultType))
+ return shapedType.getElementType();
+ return resultType;
+}
+
//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 11b921de21596..556d8762ade6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -349,6 +349,40 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
return success();
}
+/// Convert ptr.ptr_diff operation
+static LogicalResult
+convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
+ llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
+
+ if (!lhs || !rhs)
+ return ptrDiffOp.emitError("Failed to lookup operands");
+
+ // Convert result type to LLVM type
+ llvm::Type *resultType =
+ moduleTranslation.convertType(ptrDiffOp.getResult().getType());
+ if (!resultType)
+ return ptrDiffOp.emitError("Failed to convert result type");
+
+ PtrDiffFlags flags = ptrDiffOp.getFlags();
+
+ // Convert both pointers to integers using ptrtoaddr, and compute the
+ // difference: lhs - rhs
+ llvm::Value *result = builder.CreateSub(
+ builder.CreatePtrToAddr(lhs), builder.CreatePtrToAddr(rhs), /*Name=*/"",
+ /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
+ /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
+
+ // Convert the difference to the expected result type by truncating or
+ // extending.
+ if (result->getType() != resultType)
+ result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);
+
+ moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
+ return success();
+}
+
/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
@@ -369,6 +403,9 @@ class PtrDialectLLVMIRTranslationInterface
.Case([&](PtrAddOp ptrAddOp) {
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
})
+ .Case([&](PtrDiffOp ptrDiffOp) {
+ return convertPtrDiffOp(ptrDiffOp, builder, moduleTranslation);
+ })
.Case([&](LoadOp loadOp) {
return convertLoadOp(loadOp, builder, moduleTranslation);
})
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index cc1eeb3cb5744..83e1c880650c5 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>,
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}
+
+// -----
+
+func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> {
+ // expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}}
+ %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
+ return %res : vector<8xi64>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 7b2254185f57c..0a906ad559e21 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
}
+
+/// Test ptr_diff operations with scalar pointers
+func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) {
+ %diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64
+ %diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index
+ %diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32
+ return %diff_i64, %diff_index, %diff_i32 : i64, index, i32
+}
+
+/// Test ptr_diff operations with vector pointers
+func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) {
+ %diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+ %diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex>
+ return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex>
+}
+
+/// Test ptr_diff operations with tensor pointers
+func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) {
+ %diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64>
+ %diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+ return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32>
+}
+
+/// Test ptr_diff operations with 2D tensor pointers
+func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> {
+ %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
+ return %diff : tensor<4x8xi64>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 17954c9eaa8c6..ce5a13f680394 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -259,3 +259,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}
+
+// CHECK-LABEL: define i64 @ptr_diff_scalar
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
+// CHECK-NEXT: ret i32 %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 {
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32
+ llvm.return %diff : i32
+}
+
+// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
+// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret <4 x i64> %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> {
+ %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64>
+ llvm.return %diffs : vector<4xi64>
+}
+
+// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
+// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
+// CHECK-NEXT: ret <8 x i32> %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
+ %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
+ llvm.return %diffs : vector<8xi32>
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_constants() {
+// CHECK-NEXT: ret i64 4096
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_constants() -> i64 {
+ %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>>
+ %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>>
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
|
@llvm/pr-subscribers-mlir-llvm Author: Fabian Mora (fabianmcg) ChangesThi patch introduces the
This patch also adds translation to LLVM IR hooks for the Example: llvm.func @<!-- -->ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
%diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
llvm.return %diffs : vector<8xi32>
} Translation to LLVM IR: define <8 x i32> @<!-- -->ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
%3 = ptrtoint <8 x ptr> %0 to <8 x i64>
%4 = ptrtoint <8 x ptr> %1 to <8 x i64>
%5 = sub <8 x i64> %3, %4
%6 = trunc <8 x i64> %5 to <8 x i32>
ret <8 x i32> %6
} Full diff: https://github.com/llvm/llvm-project/pull/157354.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
index c169f48e573d0..c97bd04d32896 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [
let cppNamespace = "::mlir::ptr";
}
+//===----------------------------------------------------------------------===//
+// Ptr diff flags enum properties.
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [
+ I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2>
+ ]> {
+ let cppNamespace = "::mlir::ptr";
+}
+
#endif // PTR_ENUMS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 468a3004d5c62..85902fdf7159e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
+ Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape
+ ]> {
+ let summary = "Pointer difference operation";
+ let description = [{
+ The `ptr_diff` operation computes the difference between two pointers,
+ returning an integer or index value representing the number of bytes
+ between them. This difference is always computed using signed arithmetic.
+
+ The operation supports both scalar and shaped types with value semantics:
+ - When both operands are scalar: produces a single difference value
+ - When both are shaped: performs element-wise subtraction,
+ shapes must be the same
+
+ The operation also supports the following flags:
+ - `none`: No flags are set.
+ - `nuw`: No Unsigned Wrap, if the addition causes an unsigned overflow,
+ the result is a poison value.
+ - `nsw`: No Signed Wrap, if the addition causes a signed overflow, the
+ result is a poison value.
+
+ NOTE: The pointer difference is calculated using an integer type specified
+ by the data layout. The final result will be sign-extended or truncated to
+ fit the result type as necessary.
+
+ Example:
+
+ ```mlir
+ // Scalar pointers
+ %diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64
+
+ // Shaped pointers
+ %diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+ ```
+ }];
+ let arguments = (ins
+ Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs,
+ DefaultValuedProp<EnumProp<Ptr_PtrDiffFlags>, "PtrDiffFlags::none">:$flags
+ );
+ let results = (outs Ptr_IntLikeType:$result);
+ let assemblyFormat = [{
+ ($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
+ }];
+ let extraClassDeclaration = [{
+ /// Returns the operand's ptr type.
+ ptr::PtrType getPtrType();
+ /// Returns the result's underlying int type.
+ Type getIntType();
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index f0209af8a1ca3..51f25f755a8a6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes(
return success();
}
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PtrDiffOp::verify() {
+ // If the operands are not shaped early exit.
+ if (!isa<ShapedType>(getLhs().getType()))
+ return success();
+
+ // Just check the container type matches, `SameOperandsAndResultShape` handles
+ // the actual shape.
+ if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
+ return emitError() << "expected the result to have the same container "
+ "type as the operands when operands are shaped";
+ }
+
+ return success();
+}
+
+ptr::PtrType PtrDiffOp::getPtrType() {
+ Type lhsType = getLhs().getType();
+ if (auto shapedType = dyn_cast<ShapedType>(lhsType))
+ return cast<ptr::PtrType>(shapedType.getElementType());
+ return cast<ptr::PtrType>(lhsType);
+}
+
+Type PtrDiffOp::getIntType() {
+ Type resultType = getResult().getType();
+ if (auto shapedType = dyn_cast<ShapedType>(resultType))
+ return shapedType.getElementType();
+ return resultType;
+}
+
//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 11b921de21596..556d8762ade6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -349,6 +349,40 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
return success();
}
+/// Convert ptr.ptr_diff operation
+static LogicalResult
+convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
+ llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
+
+ if (!lhs || !rhs)
+ return ptrDiffOp.emitError("Failed to lookup operands");
+
+ // Convert result type to LLVM type
+ llvm::Type *resultType =
+ moduleTranslation.convertType(ptrDiffOp.getResult().getType());
+ if (!resultType)
+ return ptrDiffOp.emitError("Failed to convert result type");
+
+ PtrDiffFlags flags = ptrDiffOp.getFlags();
+
+ // Convert both pointers to integers using ptrtoaddr, and compute the
+ // difference: lhs - rhs
+ llvm::Value *result = builder.CreateSub(
+ builder.CreatePtrToAddr(lhs), builder.CreatePtrToAddr(rhs), /*Name=*/"",
+ /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
+ /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
+
+ // Convert the difference to the expected result type by truncating or
+ // extending.
+ if (result->getType() != resultType)
+ result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);
+
+ moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
+ return success();
+}
+
/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
@@ -369,6 +403,9 @@ class PtrDialectLLVMIRTranslationInterface
.Case([&](PtrAddOp ptrAddOp) {
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
})
+ .Case([&](PtrDiffOp ptrDiffOp) {
+ return convertPtrDiffOp(ptrDiffOp, builder, moduleTranslation);
+ })
.Case([&](LoadOp loadOp) {
return convertLoadOp(loadOp, builder, moduleTranslation);
})
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index cc1eeb3cb5744..83e1c880650c5 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>,
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}
+
+// -----
+
+func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> {
+ // expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}}
+ %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
+ return %res : vector<8xi64>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 7b2254185f57c..0a906ad559e21 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
}
+
+/// Test ptr_diff operations with scalar pointers
+func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) {
+ %diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64
+ %diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index
+ %diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32
+ return %diff_i64, %diff_index, %diff_i32 : i64, index, i32
+}
+
+/// Test ptr_diff operations with vector pointers
+func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) {
+ %diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+ %diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex>
+ return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex>
+}
+
+/// Test ptr_diff operations with tensor pointers
+func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) {
+ %diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64>
+ %diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+ return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32>
+}
+
+/// Test ptr_diff operations with 2D tensor pointers
+func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> {
+ %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
+ return %diff : tensor<4x8xi64>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 17954c9eaa8c6..ce5a13f680394 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -259,3 +259,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}
+
+// CHECK-LABEL: define i64 @ptr_diff_scalar
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
+// CHECK-NEXT: ret i32 %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 {
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32
+ llvm.return %diff : i32
+}
+
+// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
+// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret <4 x i64> %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> {
+ %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64>
+ llvm.return %diffs : vector<4xi64>
+}
+
+// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
+// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
+// CHECK-NEXT: ret <8 x i32> %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
+ %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
+ llvm.return %diffs : vector<8xi32>
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_constants() {
+// CHECK-NEXT: ret i64 4096
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_constants() -> i64 {
+ %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>>
+ %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>>
+ %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT: ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+ %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ llvm.return %diff : i64
+}
|
|
f467d25
to
5e4b984
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces the ptr.ptr_diff
operation to the Ptr dialect for computing pointer differences with support for various flags and shaped types. The operation returns the byte difference between two pointers as an integer or index value.
Key changes include:
- Adding the
ptr.ptr_diff
operation with overflow flags (nsw, nuw) - Supporting both scalar and vector/tensor shaped types with element-wise semantics
- Implementing LLVM IR translation that converts pointers to integers and performs subtraction
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | Defines PtrDiffFlags enum for overflow behavior flags |
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | Defines the ptr_diff operation with its syntax and semantics |
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | Implements verification logic and helper methods for ptr_diff |
mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | Adds LLVM IR translation for ptr_diff operation |
mlir/test/Dialect/Ptr/ops.mlir | Tests valid ptr_diff operations with various types and flags |
mlir/test/Dialect/Ptr/invalid.mlir | Tests error handling for invalid ptr_diff usage |
mlir/test/Target/LLVMIR/ptr.mlir | Tests LLVM IR generation for ptr_diff operations |
362ce4b
to
966d174
Compare
5e4b984
to
a144c1b
Compare
Thi patch introduces the `ptr.ptr_diff` operation for computing pointer differences. The semantics of the operation are given by: ``` The `ptr_diff` operation computes the difference between two pointers, returning an integer or index value representing the number of bytes between them. This difference is always computed using signed arithmetic. The operation supports both scalar and shaped types with value semantics: - When both operands are scalar: produces a single difference value - When both are shaped: performs element-wise subtraction, shapes must be the same The operation also supports the following flags: - `none`: No flags are set. - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow, the result is a poison value. - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the result is a poison value. NOTE: The pointer difference is calculated using an integer type specified by the data layout. The final result will be sign-extended or truncated to fit the result type as necessary. ``` This patch also adds translation to LLVM IR hooks for the `ptr_diff` op. This translation uses the `ptrtoaddr` builder to compute only index bits difference. Example: ```mlir llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> { %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> llvm.return %diffs : vector<8xi32> } ``` Translation to LLVM IR: ```llvm define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) { %3 = ptrtoint <8 x ptr> %0 to <8 x i64> %4 = ptrtoint <8 x ptr> %1 to <8 x i64> %5 = sub <8 x i64> %3, %4 %6 = trunc <8 x i64> %5 to <8 x i32> ret <8 x i32> %6 } ```
a144c1b
to
c25dd35
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the behavior for an address space mismatch?
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Interfaces/DataLayoutInterfaces.h" | ||
#include "mlir/Transforms/InliningUtils.h" | ||
#include "llvm/ADT/StringExtras.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like an accidental include.
The lhs and rhs types must match, so there’s never an address space mismatch. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, makes sense then.
LGTM!
let description = [{ | ||
The `ptr_diff` operation computes the difference between two pointers, | ||
returning an integer or index value representing the number of bytes | ||
between them. This difference is always computed using signed arithmetic. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
between them. This difference is always computed using signed arithmetic. | |
between them. This difference is always computed using signed arithmetic. |
What does it mean? Can you provide an example of "signed" vs "unsigned" arithmetic here?
Can a pointer itself be negative?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Negative pointers no, but take: (ptr + 16) - (ptr + 32), the result will be negative.
I took as rationale the existence of ptrdiff_t
in C https://en.cppreference.com/w/c/types/ptrdiff_t.html
I'll clarify the comment, because it's not clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Be mindful that C is a language with signed type, while we're operating in a signless IR here.
The fact that the result can be negative is an obvious aspect of the operation being a subtraction, it has not much to do with "signed arithmetic".
I'd look at arith.subi for inspiration.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the comment because I already have the phrase:
The final result will be sign-extended or truncated to
fit the result type as necessary.
in the description.
Co-authored-by: Mehdi Amini <[email protected]>
return shapedType.getElementType(); | ||
return resultType; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we get a folder for the subtraction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, in general I need to improve the folders and canonicalizers of the entire dialect, so how about a new PR for that?
Thi patch introduces the `ptr.ptr_diff` operation for computing pointer differences. The semantics of the operation are given by: ``` The `ptr_diff` operation computes the difference between two pointers, returning an integer or index value representing the number of bytes between them. The operation supports both scalar and shaped types with value semantics: - When both operands are scalar: produces a single difference value - When both are shaped: performs element-wise subtraction, shapes must be the same The operation also supports the following flags: - `none`: No flags are set. - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow, the result is a poison value. - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the result is a poison value. NOTE: The pointer difference is calculated using an integer type specified by the data layout. The final result will be sign-extended or truncated to fit the result type as necessary. ``` This patch also adds translation to LLVM IR hooks for the `ptr_diff` op. This translation uses the `ptrtoaddr` builder to compute only index bits difference. Example: ```mlir llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> { %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> llvm.return %diffs : vector<8xi32> } ``` Translation to LLVM IR: ```llvm define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) { %3 = ptrtoint <8 x ptr> %0 to <8 x i64> %4 = ptrtoint <8 x ptr> %1 to <8 x i64> %5 = sub <8 x i64> %3, %4 %6 = trunc <8 x i64> %5 to <8 x i32> ret <8 x i32> %6 } ``` --------- Co-authored-by: Mehdi Amini <[email protected]>
Thi patch introduces the
ptr.ptr_diff
operation for computing pointer differences. The semantics of the operation are given by:This patch also adds translation to LLVM IR hooks for the
ptr_diff
op. This translation uses theptrtoaddr
builder to compute only index bits difference.Example:
Translation to LLVM IR: