@fabianmcg fabianmcg commented Sep 7, 2025

This patch introduces the ptr.ptr_diff operation for computing pointer differences. The semantics of the operation are given by:

The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them.

The operation supports both scalar and shaped types with value semantics:
- When both operands are scalar: produces a single difference value
- When both are shaped: performs element-wise subtraction,
  shapes must be the same

The operation also supports the following flags:
- `none`: No flags are set.
- `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow,
  the result is a poison value.
- `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
  result is a poison value.

NOTE: The pointer difference is calculated using an integer type specified
by the data layout. The final result will be sign-extended or truncated to
fit the result type as necessary.
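
To make the flag and result-type rules above concrete, here is a minimal C++ sketch that models the op on plain 64-bit addresses. It is not the patch's implementation: `DiffFlags`, `ptrDiff`, and `truncToWidth` are hypothetical names, the 64-bit address width is an assumption, and `std::nullopt` stands in for a poison value.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical stand-in for the op's flags; not the generated MLIR enum.
enum class DiffFlags : std::uint8_t { none = 0, nuw = 1, nsw = 2 };

constexpr DiffFlags operator|(DiffFlags a, DiffFlags b) {
  return static_cast<DiffFlags>(static_cast<std::uint8_t>(a) |
                                static_cast<std::uint8_t>(b));
}

constexpr bool has(DiffFlags flags, DiffFlags bit) {
  return (static_cast<std::uint8_t>(flags) &
          static_cast<std::uint8_t>(bit)) != 0;
}

// Reinterprets a 64-bit pattern as a two's-complement signed value.
constexpr std::int64_t asSigned(std::uint64_t v) {
  constexpr std::uint64_t signBit = std::uint64_t{1} << 63;
  return (v & signBit) ? -static_cast<std::int64_t>(~v) - 1
                       : static_cast<std::int64_t>(v);
}

// Models `lhs - rhs` on assumed 64-bit addresses.
// std::nullopt is this sketch's stand-in for a poison result.
std::optional<std::int64_t> ptrDiff(std::uint64_t lhs, std::uint64_t rhs,
                                    DiffFlags flags = DiffFlags::none) {
  // nuw: the unsigned subtraction wraps exactly when lhs < rhs.
  if (has(flags, DiffFlags::nuw) && lhs < rhs)
    return std::nullopt;

  // nsw: detect signed overflow of a - b without performing it.
  const std::int64_t a = asSigned(lhs);
  const std::int64_t b = asSigned(rhs);
  constexpr std::int64_t lo = std::numeric_limits<std::int64_t>::min();
  constexpr std::int64_t hi = std::numeric_limits<std::int64_t>::max();
  const bool signedWrap = (b > 0 && a < lo + b) || (b < 0 && a > hi + b);
  if (has(flags, DiffFlags::nsw) && signedWrap)
    return std::nullopt;

  // Without a triggered flag the subtraction simply wraps.
  return asSigned(lhs - rhs);
}

// Truncates `value` to `bits` bits and reads the result back as signed,
// mirroring the "truncated to fit the result type" case.
std::int64_t truncToWidth(std::int64_t value, unsigned bits) {
  assert(bits >= 1 && bits <= 64 && "unsupported integer width");
  if (bits == 64)
    return value;
  const std::uint64_t mask = (std::uint64_t{1} << bits) - 1;
  const std::uint64_t low = static_cast<std::uint64_t>(value) & mask;
  const std::uint64_t signBit = std::uint64_t{1} << (bits - 1);
  return static_cast<std::int64_t>(low ^ signBit) -
         static_cast<std::int64_t>(signBit);
}
```

In this model, `ptrDiff(0x1000, 0x2000)` yields -4096 with no flags, while the same operands with `DiffFlags::nuw` yield the poison stand-in because the unsigned subtraction wraps.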

This patch also adds LLVM IR translation hooks for the ptr_diff op. The translation uses the ptrtoaddr builder, so the difference is computed over the index bits of each pointer only.

Example:

llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
  llvm.return %diffs : vector<8xi32>
}

Translation to LLVM IR:

define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
  %3 = ptrtoint <8 x ptr> %0 to <8 x i64>
  %4 = ptrtoint <8 x ptr> %1 to <8 x i64>
  %5 = sub <8 x i64> %3, %4
  %6 = trunc <8 x i64> %5 to <8 x i32>
  ret <8 x i32> %6
}
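
The lowered sequence above can be modeled lane by lane in plain C++. This is only a sketch under assumptions: `ptrDiffLanesI32` is a hypothetical helper, addresses are taken to be 64 bits wide, and each `i32` lane is kept as a raw 32-bit pattern because LLVM integers carry no signedness.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the ptrtoint / sub / trunc sequence, one lane at a time.
std::vector<std::uint32_t>
ptrDiffLanesI32(const std::vector<std::uint64_t> &lhs,
                const std::vector<std::uint64_t> &rhs) {
  // The op requires both operands to have the same shape.
  assert(lhs.size() == rhs.size() && "operand shapes must match");

  std::vector<std::uint32_t> out;
  out.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    // sub <N x i64>: unsigned arithmetic wraps, like an unflagged `sub`.
    const std::uint64_t diff = lhs[i] - rhs[i];
    // trunc <N x i64> to <N x i32>: keep only the low 32 bits.
    out.push_back(static_cast<std::uint32_t>(diff));
  }
  return out;
}
```

A lane whose left-hand address is smaller than its right-hand one wraps in 64 bits first; truncation then keeps the low 32 bits, so a difference of -4 shows up as the pattern `0xFFFFFFFC`.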

llvmbot commented Sep 7, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes

This patch introduces the ptr.ptr_diff operation for computing pointer differences. The semantics of the operation are given by:

The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them. This difference is always computed using signed arithmetic.

The operation supports both scalar and shaped types with value semantics:
- When both operands are scalar: produces a single difference value
- When both are shaped: performs element-wise subtraction,
  shapes must be the same

The operation also supports the following flags:
- `none`: No flags are set.
- `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow,
  the result is a poison value.
- `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
  result is a poison value.

NOTE: The pointer difference is calculated using an integer type specified
by the data layout. The final result will be sign-extended or truncated to
fit the result type as necessary.

This patch also adds LLVM IR translation hooks for the ptr_diff op. The translation uses the ptrtoaddr builder, so the difference is computed over the index bits of each pointer only.

Example:

llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
  llvm.return %diffs : vector<8xi32>
}

Translation to LLVM IR:

define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
  %3 = ptrtoint <8 x ptr> %0 to <8 x i64>
  %4 = ptrtoint <8 x ptr> %1 to <8 x i64>
  %5 = sub <8 x i64> %3, %4
  %6 = trunc <8 x i64> %5 to <8 x i32>
  ret <8 x i32> %6
}

Full diff: https://github.com/llvm/llvm-project/pull/157354.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td (+10)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+57)
  • (modified) mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp (+34)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp (+37)
  • (modified) mlir/test/Dialect/Ptr/invalid.mlir (+8)
  • (modified) mlir/test/Dialect/Ptr/ops.mlir (+28)
  • (modified) mlir/test/Target/LLVMIR/ptr.mlir (+96)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
index c169f48e573d0..c97bd04d32896 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [
   let cppNamespace = "::mlir::ptr";
 }
 
+//===----------------------------------------------------------------------===//
+// Ptr diff flags enum properties.
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [
+    I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2>
+  ]> {
+  let cppNamespace = "::mlir::ptr";
+}
+
 #endif // PTR_ENUMS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 468a3004d5c62..85902fdf7159e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
+    Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape
+  ]> {
+  let summary = "Pointer difference operation";
+  let description = [{
+    The `ptr_diff` operation computes the difference between two pointers,
+    returning an integer or index value representing the number of bytes
+    between them. This difference is always computed using signed arithmetic.
+
+    The operation supports both scalar and shaped types with value semantics:
+    - When both operands are scalar: produces a single difference value
+    - When both are shaped: performs element-wise subtraction,
+      shapes must be the same
+
+    The operation also supports the following flags:
+    - `none`: No flags are set.
+    - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow,
+      the result is a poison value.
+    - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
+      result is a poison value.
+
+    NOTE: The pointer difference is calculated using an integer type specified
+    by the data layout. The final result will be sign-extended or truncated to
+    fit the result type as necessary.
+
+    Example:
+
+    ```mlir
+    // Scalar pointers
+    %diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64
+
+    // Shaped pointers
+    %diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 :
+      vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+    ```
+  }];
+  let arguments = (ins
+    Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs,
+    DefaultValuedProp<EnumProp<Ptr_PtrDiffFlags>, "PtrDiffFlags::none">:$flags
+  );
+  let results = (outs Ptr_IntLikeType:$result);
+  let assemblyFormat = [{
+    ($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    /// Returns the operand's ptr type.
+    ptr::PtrType getPtrType();
+    /// Returns the result's underlying int type.
+    Type getIntType();
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index f0209af8a1ca3..51f25f755a8a6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PtrDiffOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PtrDiffOp::verify() {
+  // If the operands are not shaped, exit early.
+  if (!isa<ShapedType>(getLhs().getType()))
+    return success();
+
+  // Just check the container type matches, `SameOperandsAndResultShape` handles
+  // the actual shape.
+  if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
+    return emitError() << "expected the result to have the same container "
+                          "type as the operands when operands are shaped";
+  }
+
+  return success();
+}
+
+ptr::PtrType PtrDiffOp::getPtrType() {
+  Type lhsType = getLhs().getType();
+  if (auto shapedType = dyn_cast<ShapedType>(lhsType))
+    return cast<ptr::PtrType>(shapedType.getElementType());
+  return cast<ptr::PtrType>(lhsType);
+}
+
+Type PtrDiffOp::getIntType() {
+  Type resultType = getResult().getType();
+  if (auto shapedType = dyn_cast<ShapedType>(resultType))
+    return shapedType.getElementType();
+  return resultType;
+}
+
 //===----------------------------------------------------------------------===//
 // ToPtrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 11b921de21596..556d8762ade6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -349,6 +349,40 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Convert ptr.ptr_diff operation
+static LogicalResult
+convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
+  llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
+
+  if (!lhs || !rhs)
+    return ptrDiffOp.emitError("Failed to lookup operands");
+
+  // Convert result type to LLVM type
+  llvm::Type *resultType =
+      moduleTranslation.convertType(ptrDiffOp.getResult().getType());
+  if (!resultType)
+    return ptrDiffOp.emitError("Failed to convert result type");
+
+  PtrDiffFlags flags = ptrDiffOp.getFlags();
+
+  // Convert both pointers to integers using ptrtoaddr, and compute the
+  // difference: lhs - rhs
+  llvm::Value *result = builder.CreateSub(
+      builder.CreatePtrToAddr(lhs), builder.CreatePtrToAddr(rhs), /*Name=*/"",
+      /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
+      /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
+
+  // Convert the difference to the expected result type by truncating or
+  // extending.
+  if (result->getType() != resultType)
+    result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);
+
+  moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
+  return success();
+}
+
 /// Implementation of the dialect interface that converts operations belonging
 /// to the `ptr` dialect to LLVM IR.
 class PtrDialectLLVMIRTranslationInterface
@@ -369,6 +403,9 @@ class PtrDialectLLVMIRTranslationInterface
         .Case([&](PtrAddOp ptrAddOp) {
           return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
         })
+        .Case([&](PtrDiffOp ptrDiffOp) {
+          return convertPtrDiffOp(ptrDiffOp, builder, moduleTranslation);
+        })
         .Case([&](LoadOp loadOp) {
           return convertLoadOp(loadOp, builder, moduleTranslation);
         })
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index cc1eeb3cb5744..83e1c880650c5 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>,
   %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
   return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
 }
+
+// -----
+
+func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> {
+  // expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}}
+  %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
+  return %res : vector<8xi64>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 7b2254185f57c..0a906ad559e21 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p
   %addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
   return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
 }
+
+/// Test ptr_diff operations with scalar pointers
+func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) {
+  %diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64
+  %diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index
+  %diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32
+  return %diff_i64, %diff_index, %diff_i32 : i64, index, i32
+}
+
+/// Test ptr_diff operations with vector pointers
+func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) {
+  %diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64>
+  %diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex>
+  return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex>
+}
+
+/// Test ptr_diff operations with tensor pointers
+func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) {
+  %diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64>
+  %diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+  return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32>
+}
+
+/// Test ptr_diff operations with 2D tensor pointers
+func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> {
+  %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
+  return %diff : tensor<4x8xi64>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 17954c9eaa8c6..ce5a13f680394 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -259,3 +259,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
   %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
   llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
 }
+
+// CHECK-LABEL: define i64 @ptr_diff_scalar
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
+// CHECK-NEXT:   ret i32 %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 {
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32
+  llvm.return %diff : i32
+}
+
+// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
+// CHECK-NEXT:   %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret <4 x i64> %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> {
+  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64>
+  llvm.return %diffs : vector<4xi64>
+}
+
+// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
+// CHECK-NEXT:   %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
+// CHECK-NEXT:   ret <8 x i32> %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
+  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
+  llvm.return %diffs : vector<8xi32>
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_constants() {
+// CHECK-NEXT:   ret i64 4096
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_constants() -> i64 {
+  %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>>
+  %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>>
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}

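The translation hook in the diff above tests each flag with a mask-and-compare before passing it to the `sub` builder. A standalone C++ sketch of that pattern follows. `PtrDiffFlags` here is a hand-written stand-in for the TableGen-generated bit enum (same case values as in `PtrEnums.td`), and `subMnemonic` is a hypothetical helper that only spells out the flags in the order the CHECK lines expect.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hand-written stand-in for the generated enum; values mirror PtrEnums.td.
enum class PtrDiffFlags : std::uint8_t { none = 0, nuw = 1, nsw = 2 };

constexpr PtrDiffFlags operator|(PtrDiffFlags a, PtrDiffFlags b) {
  return static_cast<PtrDiffFlags>(static_cast<std::uint8_t>(a) |
                                   static_cast<std::uint8_t>(b));
}

constexpr PtrDiffFlags operator&(PtrDiffFlags a, PtrDiffFlags b) {
  return static_cast<PtrDiffFlags>(static_cast<std::uint8_t>(a) &
                                   static_cast<std::uint8_t>(b));
}

// Hypothetical helper: renders the `sub` opcode with the flags that are set.
std::string subMnemonic(PtrDiffFlags flags) {
  assert(static_cast<std::uint8_t>(flags) <= 3 && "unknown flag bits");
  std::string text = "sub";
  // Same mask-and-compare test as the translation hook uses.
  if ((flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw)
    text += " nuw";
  if ((flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw)
    text += " nsw";
  return text;
}
```

Because the cases are independent bits, `nsw | nuw` sets both, which is why the last test above expects `sub nuw nsw`.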
+  %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
+  return %diff : tensor<4x8xi64>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 17954c9eaa8c6..ce5a13f680394 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -259,3 +259,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
   %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
   llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
 }
+
+// CHECK-LABEL: define i64 @ptr_diff_scalar
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
+// CHECK-NEXT:   ret i32 %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 {
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32
+  llvm.return %diff : i32
+}
+
+// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
+// CHECK-NEXT:   %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret <4 x i64> %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> {
+  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64>
+  llvm.return %diffs : vector<4xi64>
+}
+
+// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
+// CHECK-NEXT:   %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
+// CHECK-NEXT:   ret <8 x i32> %[[TRUNC]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
+  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
+  llvm.return %diffs : vector<8xi32>
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_constants() {
+// CHECK-NEXT:   ret i64 4096
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_constants() -> i64 {
+  %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>>
+  %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>>
+  %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}
+
+// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
+// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
+// CHECK-NEXT:   %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
+// CHECK-NEXT:   %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT:   %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
+// CHECK-NEXT:   ret i64 %[[DIFF]]
+// CHECK-NEXT: }
+llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 {
+  %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64
+  llvm.return %diff : i64
+}

@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-diffop branch from f467d25 to 5e4b984 Compare September 7, 2025 18:02
@fabianmcg fabianmcg added mlir:ptr MLIR ptr dialect and removed mlir:llvm mlir labels Sep 7, 2025
@fabianmcg fabianmcg requested a review from Copilot September 7, 2025 18:02
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR introduces the ptr.ptr_diff operation to the Ptr dialect for computing pointer differences with support for various flags and shaped types. The operation returns the byte difference between two pointers as an integer or index value.

Key changes include:

  • Adding the ptr.ptr_diff operation with overflow flags (nsw, nuw)
  • Supporting both scalar and vector/tensor shaped types with element-wise semantics
  • Implementing LLVM IR translation that converts pointers to integers and performs subtraction

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | Defines PtrDiffFlags enum for overflow behavior flags |
| mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | Defines the ptr_diff operation with its syntax and semantics |
| mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | Implements verification logic and helper methods for ptr_diff |
| mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | Adds LLVM IR translation for ptr_diff operation |
| mlir/test/Dialect/Ptr/ops.mlir | Tests valid ptr_diff operations with various types and flags |
| mlir/test/Dialect/Ptr/invalid.mlir | Tests error handling for invalid ptr_diff usage |
| mlir/test/Target/LLVMIR/ptr.mlir | Tests LLVM IR generation for ptr_diff operations |

@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-constantop branch from 362ce4b to 966d174 Compare September 14, 2025 13:56
@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-diffop branch from 5e4b984 to a144c1b Compare September 14, 2025 14:33
Base automatically changed from users/fabianmcg/ptr-constantop to main September 14, 2025 15:45
This patch introduces the `ptr.ptr_diff` operation for computing pointer
differences. The semantics of the operation are given by:
```
The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them. This difference is always computed using signed arithmetic.

The operation supports both scalar and shaped types with value semantics:
- When both operands are scalar: produces a single difference value
- When both are shaped: performs element-wise subtraction,
  shapes must be the same

The operation also supports the following flags:
- `none`: No flags are set.
- `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow,
  the result is a poison value.
- `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
  result is a poison value.

NOTE: The pointer difference is calculated using an integer type specified
by the data layout. The final result will be sign-extended or truncated to
fit the result type as necessary.
```
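
The flag semantics above can be sketched in plain C++. This is only an illustrative model, not code from the patch: the enum values, the function name `ptrDiffModel`, and the fixed 64-bit index width are assumptions, and `std::nullopt` merely stands in for a poison result.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical flag encoding; the real PtrDiffFlags values may differ.
enum PtrDiffFlags : unsigned { none = 0, nuw = 1, nsw = 2 };

// Models ptr_diff on an assumed 64-bit index type.
// std::nullopt stands in for poison, which is not a real runtime value.
std::optional<uint64_t> ptrDiffModel(uint64_t lhs, uint64_t rhs,
                                     unsigned flags) {
  // Two's-complement subtraction yields the same bits whether the
  // operands are viewed as signed or unsigned.
  uint64_t diff = lhs - rhs;
  // Unsigned wrap occurs exactly when lhs < rhs as unsigned values.
  bool unsignedWrap = lhs < rhs;
  // Signed overflow occurs when the operands differ in sign and the
  // result's sign differs from the sign of lhs.
  bool signedWrap = (((lhs ^ rhs) & (lhs ^ diff)) >> 63) != 0;
  if ((flags & nuw) && unsignedWrap)
    return std::nullopt;
  if ((flags & nsw) && signedWrap)
    return std::nullopt;
  return diff;
}
```

Under this model, subtracting a larger address from a smaller one is fine with `nsw` or `none` but yields the poison stand-in with `nuw`.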

This patch also adds LLVM IR translation hooks for the `ptr_diff` op.
The translation uses the `ptrtoaddr` builder so that the difference is
computed over the index bits only.

Example:
```mlir
llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
  llvm.return %diffs : vector<8xi32>
}
```
Translation to LLVM IR:
```llvm
define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
  %3 = ptrtoint <8 x ptr> %0 to <8 x i64>
  %4 = ptrtoint <8 x ptr> %1 to <8 x i64>
  %5 = sub <8 x i64> %3, %4
  %6 = trunc <8 x i64> %5 to <8 x i32>
  ret <8 x i32> %6
}
```
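
As a rough lane-by-lane reading of the LLVM IR above, the following C++ sketch mirrors the `ptrtoint`, `sub`, and `trunc` steps on plain integers. The function name `ptrDiffVectorI32` is hypothetical, and addresses are assumed to already be available as 64-bit integers.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical model of the lowered sequence for 8 lanes. Each input
// element plays the role of a pointer already converted to i64.
std::array<uint32_t, 8> ptrDiffVectorI32(const std::array<uint64_t, 8> &lhs,
                                         const std::array<uint64_t, 8> &rhs) {
  std::array<uint32_t, 8> result{};
  for (std::size_t i = 0; i < result.size(); ++i) {
    uint64_t diff = lhs[i] - rhs[i];         // sub <8 x i64>
    result[i] = static_cast<uint32_t>(diff); // trunc keeps the low 32 bits
  }
  return result;
}
```

Truncation discards the upper 32 bits of each lane, so a difference that does not fit in 32 bits loses information.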
@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-diffop branch from a144c1b to c25dd35 Compare September 14, 2025 17:08
Contributor

@Dinistro Dinistro left a comment

What is the behavior for an address space mismatch?

#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringExtras.h"
Contributor

This looks like an accidental include.

@fabianmcg
Contributor Author

> What is the behavior for an address space mismatch?

The lhs and rhs types must match, so there’s never an address space mismatch.

Contributor

@Dinistro Dinistro left a comment

Right, makes sense then.
LGTM!

let description = [{
The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them. This difference is always computed using signed arithmetic.
Collaborator

@joker-eph joker-eph Sep 15, 2025


What does it mean? Can you provide an example of "signed" vs "unsigned" arithmetic here?
Can a pointer itself be negative?

Contributor Author

@fabianmcg fabianmcg Sep 15, 2025

Negative pointers no, but take: (ptr + 16) - (ptr + 32), the result will be negative.

I took as rationale the existence of ptrdiff_t in C https://en.cppreference.com/w/c/types/ptrdiff_t.html

I'll clarify the comment, because it's not clear.
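
For reference, the C-level analogue of that example can be written as the small sketch below. It uses `char` pointers so the unit is bytes; the function name `negativeDiffExample` and the buffer are illustrative only.

```cpp
#include <cstddef>

// Computes (ptr + 16) - (ptr + 32) on char pointers into one array,
// so the difference is measured in bytes and has type std::ptrdiff_t.
std::ptrdiff_t negativeDiffExample() {
  static char buffer[64];
  char *ptr = buffer;
  return (ptr + 16) - (ptr + 32);
}
```

Both pointers stay inside the same array, which is what makes the subtraction well defined in C and C++.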

Collaborator

@joker-eph joker-eph Sep 15, 2025

Be mindful that C is a language with signed types, while we're operating in a signless IR here.

The fact that the result can be negative is an obvious aspect of the operation being a subtraction, it has not much to do with "signed arithmetic".
I'd look at arith.subi for inspiration.
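
To illustrate the signless point, here is a minimal C++ sketch (helper names are hypothetical): the subtraction produces one bit pattern, and signedness only appears when a consumer chooses how to read those bits.

```cpp
#include <cstdint>
#include <cstring>

// Subtract as raw 64-bit patterns, which is how a signless `sub` behaves.
uint64_t subBits(uint64_t lhs, uint64_t rhs) { return lhs - rhs; }

// Read the bits as a signed value. memcpy avoids relying on
// implementation-defined conversions in older C++ standards.
int64_t asSigned(uint64_t bits) {
  int64_t value;
  std::memcpy(&value, &bits, sizeof value);
  return value;
}

// Read the same bits as an unsigned value.
uint64_t asUnsigned(uint64_t bits) { return bits; }
```

The same result of `subBits(16, 32)` reads as a small negative number through `asSigned` and as a very large number through `asUnsigned`, without the subtraction itself changing.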

Contributor Author

Removed the comment because I already have the phrase:

> The final result will be sign-extended or truncated to
> fit the result type as necessary.

in the description.
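
A small C++ sketch of what that phrase implies, under the assumption of a 32-bit index type chosen by the data layout (the helper names are hypothetical and not part of the patch): widening sign-extends, narrowing truncates, and zero extension is shown only for contrast.

```cpp
#include <cstdint>

// Assumed setup: the difference was computed in a 32-bit index type.

// Widen to 64 bits with sign extension, as a signed integer cast would.
// The uint32_t to int32_t conversion is modular on mainstream compilers
// and guaranteed to be so from C++20 onward.
int64_t widenSigned(uint32_t diffBits) {
  return static_cast<int64_t>(static_cast<int32_t>(diffBits));
}

// Zero extension, shown only to contrast with sign extension.
uint64_t widenUnsigned(uint32_t diffBits) { return diffBits; }

// Narrow to 16 bits by keeping the low bits (truncation).
uint16_t narrow(uint32_t diffBits) { return static_cast<uint16_t>(diffBits); }
```

With the 32-bit pattern for 16 - 32, sign extension preserves the value -16 in the wider type, while zero extension would turn it into a large positive number.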

@fabianmcg fabianmcg requested a review from joker-eph September 15, 2025 16:26
return shapedType.getElementType();
return resultType;
}

Collaborator

Can we get a folder for the subtraction?

Contributor Author

Sure, in general I need to improve the folders and canonicalizers of the entire dialect, so how about a new PR for that?

@fabianmcg fabianmcg merged commit e3aa00e into main Sep 24, 2025
9 checks passed
@fabianmcg fabianmcg deleted the users/fabianmcg/ptr-diffop branch September 24, 2025 17:09
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
This patch introduces the `ptr.ptr_diff` operation for computing pointer
differences. The semantics of the operation are given by:
```
The `ptr_diff` operation computes the difference between two pointers,
returning an integer or index value representing the number of bytes
between them.

The operation supports both scalar and shaped types with value semantics:
- When both operands are scalar: produces a single difference value
- When both are shaped: performs element-wise subtraction,
  shapes must be the same

The operation also supports the following flags:
- `none`: No flags are set.
- `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow,
  the result is a poison value.
- `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the
  result is a poison value.

NOTE: The pointer difference is calculated using an integer type specified
by the data layout. The final result will be sign-extended or truncated to
fit the result type as necessary.
```

This patch also adds LLVM IR translation hooks for the `ptr_diff` op.
The translation uses the `ptrtoaddr` builder so that the difference is
computed over the index bits only.

Example:
```mlir
llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> {
  %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
  llvm.return %diffs : vector<8xi32>
}
```
Translation to LLVM IR:
```llvm
define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) {
  %3 = ptrtoint <8 x ptr> %0 to <8 x i64>
  %4 = ptrtoint <8 x ptr> %1 to <8 x i64>
  %5 = sub <8 x i64> %3, %4
  %6 = trunc <8 x i64> %5 to <8 x i32>
  ret <8 x i32> %6
}
```

---------

Co-authored-by: Mehdi Amini <[email protected]>