Skip to content

Commit 4c067c9

Browse files
authored
[CIR][CIRGen] Fix builtin IIT Integer signedness using AST information (#1872)
1 parent e1f3a03 commit 4c067c9

File tree

3 files changed

+100
-4
lines changed

3 files changed

+100
-4
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct MissingFeatures {
7474
// GNU vectors are done, but other kinds of vectors haven't been implemented.
7575
static bool scalableVectors() { return false; }
7676
static bool vectorConstants() { return false; }
77+
static bool vectorToX86AmxCasting() { return false; }
7778

7879
// Address space related
7980
static bool addressSpace() { return false; }

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,38 @@ decodeFixedType(ArrayRef<llvm::Intrinsic::IITDescriptor> &infos,
532532
}
533533

534534
// llvm::Intrinsics accepts only LLVMContext. We need to reimplement it here.
535+
/// Helper function to correct integer signedness for intrinsic arguments.
536+
/// IIT always returns signed integers, but the actual intrinsic may expect
537+
/// unsigned integers based on the AST FunctionDecl parameter types.
538+
static mlir::Type getIntrinsicArgumentTypeFromAST(mlir::Type iitType,
539+
const CallExpr *E,
540+
unsigned argIndex,
541+
mlir::MLIRContext *context) {
542+
// If it's not an integer type, return as-is
543+
auto intTy = dyn_cast<cir::IntType>(iitType);
544+
if (!intTy)
545+
return iitType;
546+
547+
// Get the FunctionDecl from the CallExpr
548+
const FunctionDecl *FD = nullptr;
549+
if (const auto *DRE =
550+
dyn_cast<DeclRefExpr>(E->getCallee()->IgnoreImpCasts())) {
551+
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
552+
}
553+
554+
// If we have FunctionDecl and this argument exists, check its signedness
555+
if (FD && argIndex < FD->getNumParams()) {
556+
QualType paramType = FD->getParamDecl(argIndex)->getType();
557+
if (paramType->isUnsignedIntegerType()) {
558+
// Create unsigned version of the type
559+
return IntType::get(context, intTy.getWidth(), /*isSigned=*/false);
560+
}
561+
}
562+
563+
// Default: keep IIT type (signed)
564+
return iitType;
565+
}
566+
535567
static cir::FuncType getIntrinsicType(mlir::MLIRContext *context,
536568
llvm::Intrinsic::ID id) {
537569
using namespace llvm::Intrinsic;
@@ -2744,12 +2776,20 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
27442776

27452777
SmallVector<mlir::Value> args;
27462778
for (unsigned i = 0; i < E->getNumArgs(); i++) {
2747-
mlir::Value arg = emitScalarOrConstFoldImmArg(iceArguments, i, E);
2748-
mlir::Type argType = arg.getType();
2749-
if (argType != intrinsicType.getInput(i))
2779+
mlir::Value argValue = emitScalarOrConstFoldImmArg(iceArguments, i, E);
2780+
// If the intrinsic arg type is different from the builtin arg type
2781+
// we need to do a bit cast.
2782+
mlir::Type argType = argValue.getType();
2783+
mlir::Type expectedTy = intrinsicType.getInput(i);
2784+
2785+
// Use helper to get the correct integer type based on AST signedness
2786+
mlir::Type correctedExpectedTy =
2787+
getIntrinsicArgumentTypeFromAST(expectedTy, E, i, &getMLIRContext());
2788+
2789+
if (argType != correctedExpectedTy)
27502790
llvm_unreachable("NYI");
27512791

2752-
args.push_back(arg);
2792+
args.push_back(argValue);
27532793
}
27542794

27552795
auto intrinsicCall = builder.create<cir::LLVMIntrinsicCallOp>(
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
2+
// RUN: -fcuda-is-device -target-feature +ptx60 \
3+
// RUN: -emit-cir -o - -x cuda %s \
4+
// RUN: | FileCheck -check-prefix=CIR %s
5+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
6+
// RUN: -fcuda-is-device -target-feature +ptx65 \
7+
// RUN: -emit-cir -o - -x cuda %s \
8+
// RUN: | FileCheck -check-prefix=CIR %s
9+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
10+
// RUN: -fcuda-is-device -target-feature +ptx70 \
11+
// RUN: -emit-cir -o - -x cuda %s \
12+
// RUN: | FileCheck -check-prefix=CIR %s
13+
14+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
15+
// RUN: -fcuda-is-device -target-feature +ptx60 \
16+
// RUN: -emit-llvm -o - -x cuda %s \
17+
// RUN: | FileCheck -check-prefix=LLVM %s
18+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
19+
// RUN: -fcuda-is-device -target-feature +ptx65 \
20+
// RUN: -emit-llvm -o - -x cuda %s \
21+
// RUN: | FileCheck -check-prefix=LLVM %s
22+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
23+
// RUN: -fcuda-is-device -target-feature +ptx70 \
24+
// RUN: -emit-llvm -o - -x cuda %s \
25+
// RUN: | FileCheck -check-prefix=LLVM %s
26+
27+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_70 \
28+
// RUN: -fcuda-is-device -target-feature +ptx60 \
29+
// RUN: -emit-llvm -o - -x cuda %s \
30+
// RUN: | FileCheck -check-prefix=OGCHECK %s
31+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
32+
// RUN: -fcuda-is-device -target-feature +ptx65 \
33+
// RUN: -emit-llvm -o - -x cuda %s \
34+
// RUN: | FileCheck -check-prefix=OGCHECK %s
35+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
36+
// RUN: -fcuda-is-device -target-feature +ptx70 \
37+
// RUN: -emit-llvm -o - -x cuda %s \
38+
// RUN: | FileCheck -check-prefix=OGCHECK %s
39+
40+
#define __device__ __attribute__((device))
41+
#define __global__ __attribute__((global))
42+
#define __shared__ __attribute__((shared))
43+
#define __constant__ __attribute__((constant))
44+
45+
typedef unsigned long long uint64_t;
46+
47+
__device__ void nvvm_sync(unsigned mask, int i, float f, int a, int b,
48+
bool pred, uint64_t i64) {
49+
50+
// CIR: cir.llvm.intrinsic "nvvm.bar.warp.sync" {{.*}} : (!u32i)
51+
// LLVM: call void @llvm.nvvm.bar.warp.sync(i32
52+
// OGCHECK: call void @llvm.nvvm.bar.warp.sync(i32
53+
__nvvm_bar_warp_sync(mask);
54+
55+
}

0 commit comments

Comments
 (0)