-
Notifications
You must be signed in to change notification settings - Fork 182
[CIR][HIP] Proper Handling of address spaces in ptr-diff #1994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1514,7 +1514,28 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { | |
| return emitPointerArithmetic(CGF, Ops, /*isSubtraction=*/true); | ||
|
|
||
| // Otherwise, this is a pointer subtraction | ||
|
|
||
| mlir::Value lhs = Ops.LHS; // pointer | ||
| mlir::Value rhs = Ops.RHS; // pointer | ||
| auto loc = CGF.getLoc(Ops.Loc); | ||
|
|
||
| mlir::Type lhsTy = lhs.getType(); | ||
| mlir::Type rhsTy = rhs.getType(); | ||
|
|
||
| auto lhsPtrTy = mlir::dyn_cast<cir::PointerType>(lhsTy); | ||
| auto rhsPtrTy = mlir::dyn_cast<cir::PointerType>(rhsTy); | ||
|
|
||
| if (lhsPtrTy && rhsPtrTy) { | ||
| auto lhsAS = lhsPtrTy.getAddrSpace(); | ||
| auto rhsAS = rhsPtrTy.getAddrSpace(); | ||
|
|
||
| if (lhsAS != rhsAS) { | ||
| // Different address spaces → use addrspacecast | ||
| rhs = Builder.createAddrSpaceCast(rhs, lhsPtrTy); | ||
| } else if (lhsPtrTy != rhsPtrTy) { | ||
| // Same addrspace but different pointee/type → bitcast is fine | ||
| rhs = Builder.createBitcast(rhs, lhsPtrTy); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi, can you please try `createPtrBitcast`? I have a local patch for that function to allow passing the address space (AS); I'll upload it ASAP: ----------- clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h -----------
index b7b599e1289e..fd096ff97eb7 100644
@@ -574,7 +574,7 @@ public:
mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) {
assert(mlir::isa<cir::PointerType>(src.getType()) && "expected ptr src");
- return createBitcast(src, getPointerTo(newPointeeTy));
+ return createBitcast(src, getPointerTo(newPointeeTy, mlir::cast<cir::PointerType>(src.getType()).getAddrSpace()));
}
In addition, I have another patch to fix the case of a different AS and a different data type. I'll upload them ASAP and let you know.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion! I took a closer look at
Happy to switch if we move toward a uniform helper style later, but for this case preserving the full Thanks again for taking a look — really appreciate the review! |
||
| } | ||
| } | ||
| // Do the raw subtraction part. | ||
| // | ||
| // TODO(cir): note for LLVM lowering out of this; when expanding this into | ||
|
|
@@ -1523,7 +1544,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { | |
| // See more in `EmitSub` in CGExprScalar.cpp. | ||
| assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee()); | ||
| return cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.PtrDiffTy, | ||
| Ops.LHS, Ops.RHS); | ||
| lhs, rhs); | ||
| } | ||
|
|
||
| mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| #include "cuda.h" | ||
|
|
||
| // RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \ | ||
| // RUN: -fcuda-is-device -fhip-new-launch-api \ | ||
| // RUN: -I%S/../Inputs/ -emit-cir %s -o %t.ll | ||
| // RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.ll %s | ||
|
|
||
| // RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \ | ||
| // RUN: -fcuda-is-device -fhip-new-launch-api \ | ||
| // RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll | ||
| // RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s | ||
|
|
||
| // RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \ | ||
| // RUN: -fcuda-is-device -fhip-new-launch-api \ | ||
| // RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll | ||
| // RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s | ||
|
|
||
// Device-side function under test: subtracts the `const char*` `len` from the
// decayed local array `c_str`. `len` is initialized from `c_str`, so the
// difference is 0; the CHECK lines below verify how the two pointer operands
// are emitted, not the value. In the CIR-DEVICE checks `c_str` is a global in
// addrspace(offload_constant) while `len` is loaded as a plain pointer, so a
// `cir.cast address_space` is expected before `cir.ptr_diff`.
| __device__ int ptr_diff() { | ||
| const char c_str[] = "c-string"; | ||
| const char* len = c_str; | ||
| return c_str - len; | ||
| } | ||
|
|
||
|
|
||
| // CIR-DEVICE: %[[#LenLocalAddr:]] = cir.alloca !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>, ["len", init] | ||
| // CIR-DEVICE: %[[#GlobalPtr:]] = cir.get_global @_ZZ8ptr_diffvE5c_str : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> | ||
| // CIR-DEVICE: %[[#CastDecay:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> | ||
| // CIR-DEVICE: %[[#LenLocalAddrCast:]] = cir.cast bitcast %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>> -> !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>> | ||
| // CIR-DEVICE: cir.store align(8) %[[#CastDecay]], %[[#LenLocalAddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>, !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>> | ||
| // CIR-DEVICE: %[[#CStr:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> -> !cir.ptr<!s8i, addrspace(offload_constant)> | ||
| // CIR-DEVICE: %[[#LoadedLenAddr:]] = cir.load align(8) %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i> loc(#loc7) | ||
| // CIR-DEVICE: %[[#AddrCast:]] = cir.cast address_space %[[#LoadedLenAddr]] : !cir.ptr<!s8i> -> !cir.ptr<!s8i, addrspace(offload_constant)> | ||
| // CIR-DEVICE: %[[#DIFF:]] = cir.ptr_diff %[[#CStr]], %[[#AddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)> | ||
|
|
||
| // LLVM-DEVICE: define dso_local i32 @_Z8ptr_diffv() | ||
| // LLVM-DEVICE: %[[#GlobalPtrAddr:]] = alloca i32, i64 1, align 4, addrspace(5) | ||
| // LLVM-DEVICE: %[[#GlobalPtrCast:]] = addrspacecast ptr addrspace(5) %[[#GlobalPtrAddr]] to ptr | ||
| // LLVM-DEVICE: %[[#LenLocalAddr:]] = alloca ptr, i64 1, align 8, addrspace(5) | ||
| // LLVM-DEVICE: %[[#LenLocalAddrCast:]] = addrspacecast ptr addrspace(5) %[[#LenLocalAddr]] to ptr | ||
| // LLVM-DEVICE: store ptr addrspace(4) @_ZZ8ptr_diffvE5c_str, ptr %[[#LenLocalAddrCast]], align 8 | ||
| // LLVM-DEVICE: %[[#LoadedAddr:]] = load ptr, ptr %[[#LenLocalAddrCast]], align 8 | ||
| // LLVM-DEVICE: %[[#CastedVal:]] = addrspacecast ptr %[[#LoadedAddr]] to ptr addrspace(4) | ||
| // LLVM-DEVICE: %[[#IntVal:]] = ptrtoint ptr addrspace(4) %[[#CastedVal]] to i64 | ||
| // LLVM-DEVICE: %[[#SubVal:]] = sub i64 ptrtoint (ptr addrspace(4) @_ZZ8ptr_diffvE5c_str to i64), %[[#IntVal]] | ||
|
|
||
| // OGCG-DEVICE: define dso_local noundef i32 @_Z8ptr_diffv() #0 | ||
| // OGCG-DEVICE: %[[RETVAL:.*]] = alloca i32, align 4, addrspace(5) | ||
| // OGCG-DEVICE: %[[C_STR:.*]] = alloca [9 x i8], align 1, addrspace(5) | ||
| // OGCG-DEVICE: %[[LEN:.*]] = alloca ptr, align 8, addrspace(5) | ||
| // OGCG-DEVICE: %[[RETVAL_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RETVAL]] to ptr | ||
| // OGCG-DEVICE: %[[C_STR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[C_STR]] to ptr | ||
| // OGCG-DEVICE: %[[LEN_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[LEN]] to ptr | ||
| // OGCG-DEVICE: %[[ARRAYDECAY:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0 | ||
| // OGCG-DEVICE: store ptr %[[ARRAYDECAY]], ptr %[[LEN_ASCAST]], align 8 | ||
| // OGCG-DEVICE: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0 | ||
| // OGCG-DEVICE: %[[LOADED:.*]] = load ptr, ptr %[[LEN_ASCAST]], align 8 | ||
| // OGCG-DEVICE: %[[LHS:.*]] = ptrtoint ptr %[[ARRAYDECAY1]] to i64 | ||
| // OGCG-DEVICE: %[[RHS:.*]] = ptrtoint ptr %[[LOADED]] to i64 | ||
| // OGCG-DEVICE: %[[SUB:.*]] = sub i64 %[[LHS]], %[[RHS]] | ||
| // OGCG-DEVICE: %[[CONV:.*]] = trunc i64 %[[SUB]] to i32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should follow the OG codegen skeleton here; using similar names where possible would be helpful: