diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 0de5a87e72c3f..df43ed036d3a5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -257,11 +257,13 @@ def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">; def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">; def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">; def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">; +def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">; +def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">; /// Enum attribute of the different kinds. def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind", [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, - ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> { + ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } @@ -273,9 +275,24 @@ def NVVM_ReduxOp : Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_Type:$val, ReduxKindAttr:$kind, - I32:$mask_and_clamp)> { + I32:$mask_and_clamp, + DefaultValuedAttr:$abs, + DefaultValuedAttr:$nan)> { + let summary = "Redux Sync Op"; + let description = [{ + `redux.sync` performs a reduction operation `kind` of the 32 bit source + register across all non-exited threads in the membermask. + + The `abs` and `nan` attributes can be used in the case of f32 input type, + where the `abs` attribute causes the absolute value of the input to be used + in the reduction operation, and the `nan` attribute causes the reduction + operation to return NaN if any of the inputs to participating threads are + NaN. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync) + }]; string llvmBuilder = [{ - auto intId = getReduxIntrinsicId($_resultType, $kind); + auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan); $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp}); }]; let assemblyFormat = [{ diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 8b13735774663..6d34cf71bb780 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -25,9 +25,17 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; +#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \ + hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \ + : llvm::Intrinsic::nvvm_redux_sync_f##op##abs + +#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \ + hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN) + static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, - NVVM::ReduxKind kind) { - if (!resultType->isIntegerTy(32)) + NVVM::ReduxKind kind, + bool hasAbs, bool hasNaN) { + if (!(resultType->isIntegerTy(32) || resultType->isFloatTy())) llvm_unreachable("unsupported data type for redux"); switch (kind) { @@ -47,6 +55,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, return llvm::Intrinsic::nvvm_redux_sync_max; case NVVM::ReduxKind::MIN: return llvm::Intrinsic::nvvm_redux_sync_min; + case NVVM::ReduxKind::FMIN: + return GET_REDUX_F32_ID(min, hasAbs, hasNaN); + case NVVM::ReduxKind::FMAX: + return GET_REDUX_F32_ID(max, hasAbs, hasNaN); } llvm_unreachable("unknown redux kind"); } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index dd54acd1e317e..85998d4e66254 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 { llvm.return %r1 : i32 } +llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 { + // CHECK: nvvm.redux.sync fmin %{{.*}} + %r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32 + // CHECK: nvvm.redux.sync fmin %{{.*}} + %r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32 + // CHECK: nvvm.redux.sync fmin %{{.*}} + %r3 = nvvm.redux.sync fmin %value, %offset {NaN = true}: f32 -> f32 + // CHECK: nvvm.redux.sync fmin %{{.*}} + %r4 = nvvm.redux.sync fmin %value, %offset {abs = true, NaN = true}: f32 -> f32 + // CHECK: nvvm.redux.sync fmax %{{.*}} + %r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32 + // CHECK: nvvm.redux.sync fmax %{{.*}} + %r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32 + // CHECK: nvvm.redux.sync fmax %{{.*}} + %r7 = nvvm.redux.sync fmax %value, %offset {NaN = true}: f32 -> f32 + // CHECK: nvvm.redux.sync fmax %{{.*}} + %r8 = nvvm.redux.sync fmax %value, %offset {abs = true, NaN = true}: f32 -> f32 + llvm.return %r1 : f32 +} // ----- diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 5ab593452ab66..d11558698d860 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> llvm.return } + +// ----- +// CHECK-LABEL: @nvvm_redux_sync +llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) { + // CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}}) + %0 = nvvm.redux.sync add %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}}) + %1 = nvvm.redux.sync umax %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}}) + %2 = nvvm.redux.sync umin %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}}) + %3 = nvvm.redux.sync and %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}}) + %4 = nvvm.redux.sync or %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}}) + %5 = nvvm.redux.sync xor %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}}) + %6 = nvvm.redux.sync max %value, %offset: i32 -> i32 + // CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}}) + %7 = nvvm.redux.sync min %value, %offset: i32 -> i32 + llvm.return +} + +// CHECK-LABEL: @nvvm_redux_sync_f32 +llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) { + // CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}}) + %0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}}) + %1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}}) + %2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}}) + %3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}}) + %4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}}) + %5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}}) + %6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32 + // CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}}) + %7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32 + llvm.return +}