Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,28 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
return emitPointerArithmetic(CGF, Ops, /*isSubtraction=*/true);

// Otherwise, this is a pointer subtraction

mlir::Value lhs = Ops.LHS; // pointer
mlir::Value rhs = Ops.RHS; // pointer
auto loc = CGF.getLoc(Ops.Loc);

mlir::Type lhsTy = lhs.getType();
mlir::Type rhsTy = rhs.getType();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should follow the OG codegen skeleton here; using similar names where possible would be helpful:

// Okay, figure out the element size.
  const BinaryOperator *expr = cast<BinaryOperator>(op.E);
  QualType elementType = expr->getLHS()->getType()->getPointeeType();
 ...
  // For a variable-length array, this is going to be non-constant.
  if (const VariableArrayType *vla
        = CGF.getContext().getAsVariableArrayType(elementType)) {
    llvm_unreachable("NYI");
...
  else {
    ...
  }

auto lhsPtrTy = mlir::dyn_cast<cir::PointerType>(lhsTy);
auto rhsPtrTy = mlir::dyn_cast<cir::PointerType>(rhsTy);

if (lhsPtrTy && rhsPtrTy) {
auto lhsAS = lhsPtrTy.getAddrSpace();
auto rhsAS = rhsPtrTy.getAddrSpace();

if (lhsAS != rhsAS) {
// Different address spaces → use addrspacecast
rhs = Builder.createAddrSpaceCast(rhs, lhsPtrTy);
} else if (lhsPtrTy != rhsPtrTy) {
// Same addrspace but different pointee/type → bitcast is fine
rhs = Builder.createBitcast(rhs, lhsPtrTy);
Copy link

@mahmood82 mahmood82 Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, can you please try createPtrBitcast?

I have a local patch for that function to allow passing AS, I'll upload it ASAP:

----------- clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h -----------
index b7b599e1289e..fd096ff97eb7 100644
@@ -574,7 +574,7 @@ public:
 
   mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) {
     assert(mlir::isa<cir::PointerType>(src.getType()) && "expected ptr src");
-    return createBitcast(src, getPointerTo(newPointeeTy));
+    return createBitcast(src, getPointerTo(newPointeeTy, mlir::cast<cir::PointerType>(src.getType()).getAddrSpace()));
   }

In addition, I have another patch that fixes the case of a different AS combined with a different data type. I'll upload both ASAP and let you know.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to use createPtrBitcast. When your patch lands and fixes the AS handling, I will be happy to revise.

Copy link
Contributor Author

@koparasy koparasy Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I took a closer look at createPtrBitcast, and in this particular spot I already have the full destination pointer type (lhsPtrTy) available, including the address space information. Using createBitcast(rhs, lhsPtrTy) lets me preserve that exact type without reconstructing it.

createPtrBitcast is super useful in places where we only know the new pointee type, but here we already computed the exact pointer type we want to cast to, so sticking to createBitcast keeps the intent clearer.

Happy to switch if we move toward a uniform helper style later, but for this case preserving the full cir::PointerType seems cleaner.

Thanks again for taking a look — really appreciate the review!

}
}
// Do the raw subtraction part.
//
// TODO(cir): note for LLVM lowering out of this; when expanding this into
Expand All @@ -1523,7 +1544,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
// See more in `EmitSub` in CGExprScalar.cpp.
assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee());
return cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.PtrDiffTy,
Ops.LHS, Ops.RHS);
lhs, rhs);
}

mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) {
Expand Down
60 changes: 60 additions & 0 deletions clang/test/CIR/CodeGen/HIP/ptr-diff.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "cuda.h"

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-cir %s -o %t.ll
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.ll %s

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s

// Device-side test input for pointer subtraction: subtracts a pointer
// variable from the array it was initialized from, so both operands refer to
// the first element of `c_str`.
//
// NOTE: the CIR-DEVICE/LLVM-DEVICE/OGCG-DEVICE checks in this file match the
// local names `c_str` and `len`; keep them in sync if the body changes.
__device__ int ptr_diff() {
// Constant local array; per the CIR-DEVICE checks it is emitted as a global
// in the offload_constant address space.
const char c_str[] = "c-string";
// Pointer variable initialized from the decayed array.
const char* len = c_str;
// Array-decayed pointer minus the loaded pointer variable. Per the
// CIR-DEVICE checks the two operands have different pointer types, so an
// address_space cast of the loaded value precedes cir.ptr_diff.
return c_str - len;
}


// CIR-DEVICE: %[[#LenLocalAddr:]] = cir.alloca !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>, ["len", init]
// CIR-DEVICE: %[[#GlobalPtr:]] = cir.get_global @_ZZ8ptr_diffvE5c_str : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
// CIR-DEVICE: %[[#CastDecay:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
// CIR-DEVICE: %[[#LenLocalAddrCast:]] = cir.cast bitcast %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>> -> !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
// CIR-DEVICE: cir.store align(8) %[[#CastDecay]], %[[#LenLocalAddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>, !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
// CIR-DEVICE: %[[#CStr:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> -> !cir.ptr<!s8i, addrspace(offload_constant)>
// CIR-DEVICE: %[[#LoadedLenAddr:]] = cir.load align(8) %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i> loc(#loc7)
// CIR-DEVICE: %[[#AddrCast:]] = cir.cast address_space %[[#LoadedLenAddr]] : !cir.ptr<!s8i> -> !cir.ptr<!s8i, addrspace(offload_constant)>
// CIR-DEVICE: %[[#DIFF:]] = cir.ptr_diff %[[#CStr]], %[[#AddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>

// LLVM-DEVICE: define dso_local i32 @_Z8ptr_diffv()
// LLVM-DEVICE: %[[#GlobalPtrAddr:]] = alloca i32, i64 1, align 4, addrspace(5)
// LLVM-DEVICE: %[[#GlobalPtrCast:]] = addrspacecast ptr addrspace(5) %[[#GlobalPtrAddr]] to ptr
// LLVM-DEVICE: %[[#LenLocalAddr:]] = alloca ptr, i64 1, align 8, addrspace(5)
// LLVM-DEVICE: %[[#LenLocalAddrCast:]] = addrspacecast ptr addrspace(5) %[[#LenLocalAddr]] to ptr
// LLVM-DEVICE: store ptr addrspace(4) @_ZZ8ptr_diffvE5c_str, ptr %[[#LenLocalAddrCast]], align 8
// LLVM-DEVICE: %[[#LoadedAddr:]] = load ptr, ptr %[[#LenLocalAddrCast]], align 8
// LLVM-DEVICE: %[[#CastedVal:]] = addrspacecast ptr %[[#LoadedAddr]] to ptr addrspace(4)
// LLVM-DEVICE: %[[#IntVal:]] = ptrtoint ptr addrspace(4) %[[#CastedVal]] to i64
// LLVM-DEVICE: %[[#SubVal:]] = sub i64 ptrtoint (ptr addrspace(4) @_ZZ8ptr_diffvE5c_str to i64), %[[#IntVal]]

// OGCG-DEVICE: define dso_local noundef i32 @_Z8ptr_diffv() #0
// OGCG-DEVICE: %[[RETVAL:.*]] = alloca i32, align 4, addrspace(5)
// OGCG-DEVICE: %[[C_STR:.*]] = alloca [9 x i8], align 1, addrspace(5)
// OGCG-DEVICE: %[[LEN:.*]] = alloca ptr, align 8, addrspace(5)
// OGCG-DEVICE: %[[RETVAL_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RETVAL]] to ptr
// OGCG-DEVICE: %[[C_STR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[C_STR]] to ptr
// OGCG-DEVICE: %[[LEN_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[LEN]] to ptr
// OGCG-DEVICE: %[[ARRAYDECAY:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
// OGCG-DEVICE: store ptr %[[ARRAYDECAY]], ptr %[[LEN_ASCAST]], align 8
// OGCG-DEVICE: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
// OGCG-DEVICE: %[[LOADED:.*]] = load ptr, ptr %[[LEN_ASCAST]], align 8
// OGCG-DEVICE: %[[LHS:.*]] = ptrtoint ptr %[[ARRAYDECAY1]] to i64
// OGCG-DEVICE: %[[RHS:.*]] = ptrtoint ptr %[[LOADED]] to i64
// OGCG-DEVICE: %[[SUB:.*]] = sub i64 %[[LHS]], %[[RHS]]
// OGCG-DEVICE: %[[CONV:.*]] = trunc i64 %[[SUB]] to i32