Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flang/include/flang/Optimizer/Support/InitFIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
Expand All @@ -37,7 +38,8 @@ namespace fir::support {
mlir::scf::SCFDialect, mlir::arith::ArithDialect, \
mlir::cf::ControlFlowDialect, mlir::func::FuncDialect, \
mlir::vector::VectorDialect, mlir::math::MathDialect, \
mlir::complex::ComplexDialect, mlir::DLTIDialect, cuf::CUFDialect
mlir::complex::ComplexDialect, mlir::DLTIDialect, cuf::CUFDialect, \
mlir::NVVM::NVVMDialect

#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect

Expand Down
21 changes: 5 additions & 16 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -6548,23 +6549,11 @@ IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
assert(args.size() == 2);
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();

llvm::StringRef funcName =
is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
mlir::MLIRContext *context = builder.getContext();
mlir::Type i32Ty = builder.getI32Type();
mlir::Type i64Ty = builder.getI64Type();
mlir::Type valTy = is32 ? i32Ty : i64Ty;
mlir::Value arg1 = args[1];
if (arg1.getType().isF32() || arg1.getType().isF64())
arg1 = builder.create<fir::ConvertOp>(loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);

mlir::FunctionType ftype =
mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
auto funcOp = builder.createFunction(loc, funcName, ftype);
llvm::SmallVector<mlir::Value> filteredArgs;
filteredArgs.push_back(args[0]);
if (args[1].getType().isF32() || args[1].getType().isF64())
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
else
filteredArgs.push_back(args[1]);
return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
return builder.create<mlir::NVVM::MatchSyncOp>(loc, resultType, args[0], arg1, mlir::NVVM::MatchSyncKind::any).getResult();
}

// MATMUL
Expand Down
10 changes: 4 additions & 6 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,10 @@ attributes(device) subroutine testMatchAny()
end subroutine

! CHECK-LABEL: func.func @_QPtestmatchany()
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
! CHECK: fir.convert %{{.*}} : (f32) -> i32
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
! CHECK: fir.convert %{{.*}} : (f64) -> i64
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32

attributes(device) subroutine testAtomic(aa, n)
integer :: aa(*)
Expand Down
Loading