Skip to content

Commit 2367c0a

Browse files
committed
[flang][cuda] Use the nvvm.vote.sync op for all and any
1 parent cd2f85a commit 2367c0a

File tree

3 files changed

+19
-46
lines changed

3 files changed

+19
-46
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Runtime/iostat-consts.h"
2020
#include "mlir/Dialect/Complex/IR/Complex.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2223
#include "mlir/Dialect/Math/IR/Math.h"
2324
#include <optional>
2425

@@ -448,9 +449,8 @@ struct IntrinsicLibrary {
448449
llvm::ArrayRef<fir::ExtendedValue> args);
449450
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
450451
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
451-
mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
452-
mlir::Value genVoteAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
453-
mlir::Value genVoteBallotSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
452+
template <mlir::NVVM::VoteSyncKind kind>
453+
mlir::Value genVoteSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
454454

455455
/// Implement all conversion functions like DBLE, the first argument is
456456
/// the value to convert. There may be an additional KIND arguments that

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
#include "mlir/Dialect/Complex/IR/Complex.h"
4949
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5050
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
51-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
5251
#include "mlir/Dialect/Math/IR/Math.h"
5352
#include "mlir/Dialect/Vector/IR/VectorOps.h"
5453
#include "llvm/Support/CommandLine.h"
@@ -262,7 +261,7 @@ static constexpr IntrinsicHandler handlers[]{
262261
{{{"mask", asAddr}, {"dim", asValue}}},
263262
/*isElemental=*/false},
264263
{"all_sync",
265-
&I::genVoteAllSync,
264+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::all>,
266265
{{{"mask", asValue}, {"pred", asValue}}},
267266
/*isElemental=*/false},
268267
{"allocated",
@@ -275,7 +274,7 @@ static constexpr IntrinsicHandler handlers[]{
275274
{{{"mask", asAddr}, {"dim", asValue}}},
276275
/*isElemental=*/false},
277276
{"any_sync",
278-
&I::genVoteAnySync,
277+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::any>,
279278
{{{"mask", asValue}, {"pred", asValue}}},
280279
/*isElemental=*/false},
281280
{"asind", &I::genAsind},
@@ -341,7 +340,7 @@ static constexpr IntrinsicHandler handlers[]{
341340
{"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
342341
{"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
343342
{"ballot_sync",
344-
&I::genVoteBallotSync,
343+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
345344
{{{"mask", asValue}, {"pred", asValue}}},
346345
/*isElemental=*/false},
347346
{"bessel_jn",
@@ -6579,46 +6578,20 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
65796578
return value;
65806579
}
65816580

6582-
static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
6583-
llvm::StringRef funcName, mlir::Type resTy,
6584-
llvm::ArrayRef<mlir::Value> args) {
6585-
mlir::MLIRContext *context = builder.getContext();
6586-
mlir::Type i32Ty = builder.getI32Type();
6587-
mlir::Type i1Ty = builder.getI1Type();
6588-
mlir::FunctionType ftype =
6589-
mlir::FunctionType::get(context, {i32Ty, i1Ty}, {resTy});
6590-
auto funcOp = builder.createFunction(loc, funcName, ftype);
6591-
llvm::SmallVector<mlir::Value> filteredArgs;
6592-
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
6593-
}
6594-
6595-
// ALL_SYNC
6596-
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
6597-
llvm::ArrayRef<mlir::Value> args) {
6598-
assert(args.size() == 2);
6599-
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync",
6600-
builder.getI1Type(), args);
6601-
}
6602-
6603-
// ANY_SYNC
6604-
mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
6605-
llvm::ArrayRef<mlir::Value> args) {
6606-
assert(args.size() == 2);
6607-
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync",
6608-
builder.getI1Type(), args);
6609-
}
6610-
6611-
// BALLOT_SYNC
6612-
mlir::Value
6613-
IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
6614-
llvm::ArrayRef<mlir::Value> args) {
6581+
// ALL_SYNC, ANY_SYNC, BALLOT_SYNC
6582+
template <mlir::NVVM::VoteSyncKind kind>
6583+
mlir::Value IntrinsicLibrary::genVoteSync(mlir::Type resultType,
6584+
llvm::ArrayRef<mlir::Value> args) {
66156585
assert(args.size() == 2);
66166586
mlir::Value arg1 =
66176587
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
6618-
return builder
6619-
.create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
6620-
mlir::NVVM::VoteSyncKind::ballot)
6621-
.getResult();
6588+
mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot
6589+
? builder.getI32Type()
6590+
: builder.getI1Type();
6591+
auto voteRes =
6592+
builder.create<mlir::NVVM::VoteSyncOp>(loc, resTy, args[0], arg1, kind)
6593+
.getResult();
6594+
return builder.create<fir::ConvertOp>(loc, resultType, voteRes);
66226595
}
66236596

66246597
// MATCH_ANY_SYNC

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,8 @@ attributes(device) subroutine testVote()
301301
end subroutine
302302

303303
! CHECK-LABEL: func.func @_QPtestvote()
304-
! CHECK: fir.call @llvm.nvvm.vote.all.sync
305-
! CHECK: fir.call @llvm.nvvm.vote.any.sync
304+
! CHECK: %{{.*}} = nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
305+
! CHECK: %{{.*}} = nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
306306
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
307307

308308
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

0 commit comments

Comments
 (0)