diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index b6b114f0e4b9..ffdd7db42630 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1514,7 +1514,28 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { return emitPointerArithmetic(CGF, Ops, /*isSubtraction=*/true); // Otherwise, this is a pointer subtraction - + mlir::Value lhs = Ops.LHS; // pointer + mlir::Value rhs = Ops.RHS; // pointer + auto loc = CGF.getLoc(Ops.Loc); + + mlir::Type lhsTy = lhs.getType(); + mlir::Type rhsTy = rhs.getType(); + + auto lhsPtrTy = mlir::dyn_cast(lhsTy); + auto rhsPtrTy = mlir::dyn_cast(rhsTy); + + if (lhsPtrTy && rhsPtrTy) { + auto lhsAS = lhsPtrTy.getAddrSpace(); + auto rhsAS = rhsPtrTy.getAddrSpace(); + + if (lhsAS != rhsAS) { + // Different address spaces → use addrspacecast + rhs = Builder.createAddrSpaceCast(rhs, lhsPtrTy); + } else if (lhsPtrTy != rhsPtrTy) { + // Same addrspace but different pointee/type → bitcast is fine + rhs = Builder.createBitcast(rhs, lhsPtrTy); + } + } // Do the raw subtraction part. // // TODO(cir): note for LLVM lowering out of this; when expanding this into @@ -1523,7 +1544,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { // See more in `EmitSub` in CGExprScalar.cpp. assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee()); return cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.PtrDiffTy, - Ops.LHS, Ops.RHS); + lhs, rhs); } mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { diff --git a/clang/test/CIR/CodeGen/HIP/ptr-diff.cpp b/clang/test/CIR/CodeGen/HIP/ptr-diff.cpp new file mode 100644 index 000000000000..10cb3832b00a --- /dev/null +++ b/clang/test/CIR/CodeGen/HIP/ptr-diff.cpp @@ -0,0 +1,60 @@ +#include "cuda.h" + +// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \ +// RUN: -fcuda-is-device -fhip-new-launch-api \ +// RUN: -I%S/../Inputs/ -emit-cir %s -o %t.ll +// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.ll %s + +// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \ +// RUN: -fcuda-is-device -fhip-new-launch-api \ +// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll +// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s + +// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \ +// RUN: -fcuda-is-device -fhip-new-launch-api \ +// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll +// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s + +__device__ int ptr_diff() { + const char c_str[] = "c-string"; + const char* len = c_str; + return c_str - len; +} + + +// CIR-DEVICE: %[[#LenLocalAddr:]] = cir.alloca !cir.ptr, !cir.ptr>, ["len", init] +// CIR-DEVICE: %[[#GlobalPtr:]] = cir.get_global @_ZZ8ptr_diffvE5c_str : !cir.ptr, addrspace(offload_constant)> +// CIR-DEVICE: %[[#CastDecay:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr, addrspace(offload_constant)> +// CIR-DEVICE: %[[#LenLocalAddrCast:]] = cir.cast bitcast %[[#LenLocalAddr]] : !cir.ptr> -> !cir.ptr> +// CIR-DEVICE: cir.store align(8) %[[#CastDecay]], %[[#LenLocalAddrCast]] : !cir.ptr, !cir.ptr> +// CIR-DEVICE: %[[#CStr:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr, addrspace(offload_constant)> -> !cir.ptr +// CIR-DEVICE: %[[#LoadedLenAddr:]] = cir.load align(8) %[[#LenLocalAddr]] : !cir.ptr>, !cir.ptr loc(#loc7) +// CIR-DEVICE: %[[#AddrCast:]] = cir.cast address_space %[[#LoadedLenAddr]] : !cir.ptr -> !cir.ptr +// CIR-DEVICE: %[[#DIFF:]] = cir.ptr_diff %[[#CStr]], %[[#AddrCast]] : !cir.ptr + +// LLVM-DEVICE: define dso_local i32 @_Z8ptr_diffv() +// LLVM-DEVICE: %[[#GlobalPtrAddr:]] = alloca i32, i64 1, align 4, addrspace(5) +// LLVM-DEVICE: %[[#GlobalPtrCast:]] = addrspacecast ptr addrspace(5) %[[#GlobalPtrAddr]] to ptr +// LLVM-DEVICE: %[[#LenLocalAddr:]] = alloca ptr, i64 1, align 8, addrspace(5) +// LLVM-DEVICE: %[[#LenLocalAddrCast:]] = addrspacecast ptr addrspace(5) %[[#LenLocalAddr]] to ptr +// LLVM-DEVICE: store ptr addrspace(4) @_ZZ8ptr_diffvE5c_str, ptr %[[#LenLocalAddrCast]], align 8 +// LLVM-DEVICE: %[[#LoadedAddr:]] = load ptr, ptr %[[#LenLocalAddrCast]], align 8 +// LLVM-DEVICE: %[[#CastedVal:]] = addrspacecast ptr %[[#LoadedAddr]] to ptr addrspace(4) +// LLVM-DEVICE: %[[#IntVal:]] = ptrtoint ptr addrspace(4) %[[#CastedVal]] to i64 +// LLVM-DEVICE: %[[#SubVal:]] = sub i64 ptrtoint (ptr addrspace(4) @_ZZ8ptr_diffvE5c_str to i64), %[[#IntVal]] + +// OGCG-DEVICE: define dso_local noundef i32 @_Z8ptr_diffv() #0 +// OGCG-DEVICE: %[[RETVAL:.*]] = alloca i32, align 4, addrspace(5) +// OGCG-DEVICE: %[[C_STR:.*]] = alloca [9 x i8], align 1, addrspace(5) +// OGCG-DEVICE: %[[LEN:.*]] = alloca ptr, align 8, addrspace(5) +// OGCG-DEVICE: %[[RETVAL_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RETVAL]] to ptr +// OGCG-DEVICE: %[[C_STR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[C_STR]] to ptr +// OGCG-DEVICE: %[[LEN_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[LEN]] to ptr +// OGCG-DEVICE: %[[ARRAYDECAY:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0 +// OGCG-DEVICE: store ptr %[[ARRAYDECAY]], ptr %[[LEN_ASCAST]], align 8 +// OGCG-DEVICE: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0 +// OGCG-DEVICE: %[[LOADED:.*]] = load ptr, ptr %[[LEN_ASCAST]], align 8 +// OGCG-DEVICE: %[[LHS:.*]] = ptrtoint ptr %[[ARRAYDECAY1]] to i64 +// OGCG-DEVICE: %[[RHS:.*]] = ptrtoint ptr %[[LOADED]] to i64 +// OGCG-DEVICE: %[[SUB:.*]] = sub i64 %[[LHS]], %[[RHS]] +// OGCG-DEVICE: %[[CONV:.*]] = trunc i64 %[[SUB]] to i32