Skip to content

Commit 1a5a837

Browse files
clementval authored and mikolaj-pirog committed
[flang][cuda] Make sure operand to syncthread function is i32 (llvm#164747)
1 parent f28ed13 commit 1a5a837

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8974,10 +8974,12 @@ IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
89748974
llvm::ArrayRef<mlir::Value> args) {
89758975
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
89768976
mlir::MLIRContext *context = builder.getContext();
8977+
mlir::Type i32 = builder.getI32Type();
89778978
mlir::FunctionType ftype =
8978-
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
8979+
mlir::FunctionType::get(context, {resultType}, {i32});
89798980
auto funcOp = builder.createFunction(loc, funcName, ftype);
8980-
return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
8981+
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
8982+
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
89818983
}
89828984

89838985
// SYNCTHREADS_COUNT
@@ -8986,10 +8988,12 @@ IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
89868988
llvm::ArrayRef<mlir::Value> args) {
89878989
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
89888990
mlir::MLIRContext *context = builder.getContext();
8991+
mlir::Type i32 = builder.getI32Type();
89898992
mlir::FunctionType ftype =
8990-
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
8993+
mlir::FunctionType::get(context, {resultType}, {i32});
89918994
auto funcOp = builder.createFunction(loc, funcName, ftype);
8992-
return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
8995+
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
8996+
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
89938997
}
89948998

89958999
// SYNCTHREADS_OR
@@ -8998,10 +9002,12 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
89989002
llvm::ArrayRef<mlir::Value> args) {
89999003
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
90009004
mlir::MLIRContext *context = builder.getContext();
9005+
mlir::Type i32 = builder.getI32Type();
90019006
mlir::FunctionType ftype =
9002-
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
9007+
mlir::FunctionType::get(context, {resultType}, {i32});
90039008
auto funcOp = builder.createFunction(loc, funcName, ftype);
9004-
return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
9009+
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
9010+
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
90059011
}
90069012

90079013
// SYNCWARP

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,17 +110,20 @@ end
110110
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
111111
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
112112
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
113-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CMP]])
113+
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
114+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CONV]])
114115
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
115116
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
116117
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
117118
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
118-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CMP]]) fastmath<contract> : (i1) -> i32
119+
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
120+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CONV]]) fastmath<contract> : (i32) -> i32
119121
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
120122
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
121123
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
122124
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
123-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CMP]]) fastmath<contract> : (i1) -> i32
125+
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
126+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CONV]]) fastmath<contract> : (i32) -> i32
124127
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
125128
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
126129
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32

0 commit comments

Comments
 (0)