From 607a3e0a4d697a8c1562c62be126e5fc32fdd330 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sun, 7 Sep 2025 17:33:04 +0000 Subject: [PATCH 1/5] [mlir][ptr] Add `ptr.ptr_diff` op Thi patch introduces the `ptr.ptr_diff` operation for computing pointer differences. The semantics of the operation are given by: ``` The `ptr_diff` operation computes the difference between two pointers, returning an integer or index value representing the number of bytes between them. This difference is always computed using signed arithmetic. The operation supports both scalar and shaped types with value semantics: - When both operands are scalar: produces a single difference value - When both are shaped: performs element-wise subtraction, shapes must be the same The operation also supports the following flags: - `none`: No flags are set. - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow, the result is a poison value. - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the result is a poison value. NOTE: The pointer difference is calculated using an integer type specified by the data layout. The final result will be sign-extended or truncated to fit the result type as necessary. ``` This patch also adds translation to LLVM IR hooks for the `ptr_diff` op. This translation uses the `ptrtoaddr` builder to compute only index bits difference. Example: ```mlir llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> { %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> llvm.return %diffs : vector<8xi32> } ``` Translation to LLVM IR: ```llvm define <8 x i32> @ptr_diff_vector_i32(<8 x ptr> %0, <8 x ptr> %1) { %3 = ptrtoint <8 x ptr> %0 to <8 x i64> %4 = ptrtoint <8 x ptr> %1 to <8 x i64> %5 = sub <8 x i64> %3, %4 %6 = trunc <8 x i64> %5 to <8 x i32> ret <8 x i32> %6 } ``` --- mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | 10 ++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 57 +++++++++++ mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 34 +++++++ .../Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 37 +++++++ mlir/test/Dialect/Ptr/invalid.mlir | 8 ++ mlir/test/Dialect/Ptr/ops.mlir | 28 ++++++ mlir/test/Target/LLVMIR/ptr.mlir | 96 +++++++++++++++++++ 7 files changed, 270 insertions(+) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td index c169f48e573d0..c97bd04d32896 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td @@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [ let cppNamespace = "::mlir::ptr"; } +//===----------------------------------------------------------------------===// +// Ptr diff flags enum properties. +//===----------------------------------------------------------------------===// + +def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [ + I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2> + ]> { + let cppNamespace = "::mlir::ptr"; +} + #endif // PTR_ENUMS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 468a3004d5c62..7735210e809e3 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ }]; } +//===----------------------------------------------------------------------===// +// PtrDiffOp +//===----------------------------------------------------------------------===// + +def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [ + Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape + ]> { + let summary = "Pointer difference operation"; + let description = [{ + The `ptr_diff` operation computes the difference between two pointers, + returning an integer or index value representing the number of bytes + between them. This difference is always computed using signed arithmetic. + + The operation supports both scalar and shaped types with value semantics: + - When both operands are scalar: produces a single difference value + - When both are shaped: performs element-wise subtraction, + shapes must be the same + + The operation also supports the following flags: + - `none`: No flags are set. + - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow, + the result is a poison value. + - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the + result is a poison value. + + NOTE: The pointer difference is calculated using an integer type specified + by the data layout. The final result will be sign-extended or truncated to + fit the result type as necessary. + + Example: + + ```mlir + // Scalar pointers + %diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64 + + // Shaped pointers + %diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 : + vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64> + ``` + }]; + let arguments = (ins + Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs, + DefaultValuedProp, "PtrDiffFlags::none">:$flags + ); + let results = (outs Ptr_IntLikeType:$result); + let assemblyFormat = [{ + ($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result) + }]; + let extraClassDeclaration = [{ + /// Returns the operand's ptr type. + ptr::PtrType getPtrType(); + /// Returns the result's underlying int type. + Type getIntType(); + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index f0209af8a1ca3..51f25f755a8a6 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// PtrDiffOp +//===----------------------------------------------------------------------===// + +LogicalResult PtrDiffOp::verify() { + // If the operands are not shaped early exit. + if (!isa(getLhs().getType())) + return success(); + + // Just check the container type matches, `SameOperandsAndResultShape` handles + // the actual shape. + if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) { + return emitError() << "expected the result to have the same container " + "type as the operands when operands are shaped"; + } + + return success(); +} + +ptr::PtrType PtrDiffOp::getPtrType() { + Type lhsType = getLhs().getType(); + if (auto shapedType = dyn_cast(lhsType)) + return cast(shapedType.getElementType()); + return cast(lhsType); +} + +Type PtrDiffOp::getIntType() { + Type resultType = getResult().getType(); + if (auto shapedType = dyn_cast(resultType)) + return shapedType.getElementType(); + return resultType; +} + //===----------------------------------------------------------------------===// // ToPtrOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp index 7e610cd42e931..550d42f8d3c86 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -351,6 +351,40 @@ translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder, return success(); } +/// Convert ptr.ptr_diff operation +static LogicalResult +convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs()); + llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs()); + + if (!lhs || !rhs) + return ptrDiffOp.emitError("Failed to lookup operands"); + + // Convert result type to LLVM type + llvm::Type *resultType = + moduleTranslation.convertType(ptrDiffOp.getResult().getType()); + if (!resultType) + return ptrDiffOp.emitError("Failed to convert result type"); + + PtrDiffFlags flags = ptrDiffOp.getFlags(); + + // Convert both pointers to integers using ptrtoaddr, and compute the + // difference: lhs - rhs + llvm::Value *result = builder.CreateSub( + builder.CreatePtrToAddr(lhs), builder.CreatePtrToAddr(rhs), /*Name=*/"", + /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw, + /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw); + + // Convert the difference to the expected result type by truncating or + // extending. + if (result->getType() != resultType) + result = builder.CreateIntCast(result, resultType, /*isSigned=*/true); + + moduleTranslation.mapValue(ptrDiffOp.getResult(), result); + return success(); +} + /// Implementation of the dialect interface that translates operations belonging /// to the `ptr` dialect to LLVM IR. class PtrDialectLLVMIRTranslationInterface @@ -371,6 +405,9 @@ class PtrDialectLLVMIRTranslationInterface .Case([&](PtrAddOp ptrAddOp) { return translatePtrAddOp(ptrAddOp, builder, moduleTranslation); }) + .Case([&](PtrDiffOp ptrDiffOp) { + return convertPtrDiffOp(ptrDiffOp, builder, moduleTranslation); + }) .Case([&](LoadOp loadOp) { return translateLoadOp(loadOp, builder, moduleTranslation); }) diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir index cc1eeb3cb5744..83e1c880650c5 100644 --- a/mlir/test/Dialect/Ptr/invalid.mlir +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64> return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> } + +// ----- + +func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> { + // expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}} + %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64> + return %res : vector<8xi64> +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index 7b2254185f57c..0a906ad559e21 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p %addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>> return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>> } + +/// Test ptr_diff operations with scalar pointers +func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) { + %diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64 + %diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index + %diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32 + return %diff_i64, %diff_index, %diff_i32 : i64, index, i32 +} + +/// Test ptr_diff operations with vector pointers +func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) { + %diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64> + %diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex> + return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex> +} + +/// Test ptr_diff operations with tensor pointers +func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) { + %diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64> + %diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32> + return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32> +} + +/// Test ptr_diff operations with 2D tensor pointers +func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> { + %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64> + return %diff : tensor<4x8xi64> +} diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir index 2fa794130ec52..e2687e52ece57 100644 --- a/mlir/test/Target/LLVMIR/ptr.mlir +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -281,3 +281,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> { %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32 llvm.return %res : !ptr.ptr<#llvm.address_space<0>> } + +// CHECK-LABEL: define i64 @ptr_diff_scalar +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i32 @ptr_diff_scalar_i32 +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32 +// CHECK-NEXT: ret i32 %[[TRUNC]] +// CHECK-NEXT: } +llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 { + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32 + llvm.return %diff : i32 +} + +// CHECK-LABEL: define <4 x i64> @ptr_diff_vector +// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64> +// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret <4 x i64> %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> { + %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64> + llvm.return %diffs : vector<4xi64> +} + +// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32 +// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64> +// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32> +// CHECK-NEXT: ret <8 x i32> %[[TRUNC]] +// CHECK-NEXT: } +llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> { + %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> + llvm.return %diffs : vector<8xi32> +} + +// CHECK-LABEL: define i64 @ptr_diff_with_constants() { +// CHECK-NEXT: ret i64 4096 +// CHECK-NEXT: } +llvm.func @ptr_diff_with_constants() -> i64 { + %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>> + %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>> + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} From 5ea857101e564d493dc11c516a98c3152a6ec204 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sun, 14 Sep 2025 14:32:59 +0000 Subject: [PATCH 2/5] fix non-deterministic IR generation --- mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp index 550d42f8d3c86..3ff558ba91bb1 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -371,8 +371,10 @@ convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, // Convert both pointers to integers using ptrtoaddr, and compute the // difference: lhs - rhs + llvm::Value *llLhs = builder.CreatePtrToAddr(lhs); + llvm::Value *llRhs = builder.CreatePtrToAddr(rhs); llvm::Value *result = builder.CreateSub( - builder.CreatePtrToAddr(lhs), builder.CreatePtrToAddr(rhs), /*Name=*/"", + llLhs, llRhs, /*Name=*/"", /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw, /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw); From c25dd35390c59203d8f1755d6b1e7c4c12b3c8b5 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:08:06 +0000 Subject: [PATCH 3/5] use translate instead of convert --- .../LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp index 3ff558ba91bb1..8d6fffcca45f2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -351,21 +351,21 @@ translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.ptr_diff operation +/// Translate ptr.ptr_diff operation operation to LLVM IR. static LogicalResult -convertPtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs()); llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs()); if (!lhs || !rhs) return ptrDiffOp.emitError("Failed to lookup operands"); - // Convert result type to LLVM type + // Translate result type to LLVM type llvm::Type *resultType = moduleTranslation.convertType(ptrDiffOp.getResult().getType()); if (!resultType) - return ptrDiffOp.emitError("Failed to convert result type"); + return ptrDiffOp.emitError("Failed to translate result type"); PtrDiffFlags flags = ptrDiffOp.getFlags(); @@ -408,7 +408,7 @@ class PtrDialectLLVMIRTranslationInterface return translatePtrAddOp(ptrAddOp, builder, moduleTranslation); }) .Case([&](PtrDiffOp ptrDiffOp) { - return convertPtrDiffOp(ptrDiffOp, builder, moduleTranslation); + return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation); }) .Case([&](LoadOp loadOp) { return translateLoadOp(loadOp, builder, moduleTranslation); From f4a8eb2f7708a27762b67700304ae7938e630894 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 15 Sep 2025 09:04:10 -0400 Subject: [PATCH 4/5] Update mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td Co-authored-by: Mehdi Amini --- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 7735210e809e3..c489f1fc6e2df 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -435,8 +435,8 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [ The operation also supports the following flags: - `none`: No flags are set. - - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow, - the result is a poison value. + - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow + (that is: the result would be negative), the result is a poison value. - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the result is a poison value. From 3fca67dab2044489ddd06ea5ab945e48d3d7e090 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 15 Sep 2025 12:25:17 -0400 Subject: [PATCH 5/5] Update PtrOps.td --- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index c489f1fc6e2df..e14f64330c294 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -426,7 +426,7 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [ let description = [{ The `ptr_diff` operation computes the difference between two pointers, returning an integer or index value representing the number of bytes - between them. This difference is always computed using signed arithmetic. + between them. The operation supports both scalar and shaped types with value semantics: - When both operands are scalar: produces a single difference value