Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,13 @@ def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">;
def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">;
def ReduxKindFmin : I32EnumAttrCase<"FMIN", 9, "fmin">;
def ReduxKindFmax : I32EnumAttrCase<"FMAX", 10, "fmax">;

/// Enum attribute of the different kinds.
def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
[ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
ReduxKindUmax, ReduxKindUmin, ReduxKindXor, ReduxKindFmin, ReduxKindFmax]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
Expand All @@ -273,9 +275,24 @@ def NVVM_ReduxOp :
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
I32:$mask_and_clamp)> {
I32:$mask_and_clamp,
DefaultValuedAttr<BoolAttr, "false">:$abs,
DefaultValuedAttr<BoolAttr, "false">:$nan)> {
let summary = "Redux Sync Op";
let description = [{
`redux.sync` performs a reduction operation `kind` of the 32 bit source
register across all non-exited threads in the membermask.

The `abs` and `nan` attributes can be used in the case of f32 input type,
where the `abs` attribute causes the absolute value of the input to be used
in the reduction operation, and the `nan` attribute causes the reduction
operation to return NaN if any of the inputs to participating threads are
NaN.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync)
}];
string llvmBuilder = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add a doc-string for this op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in the latest revision. Thanks!

auto intId = getReduxIntrinsicId($_resultType, $kind);
auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
$res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
}];
let assemblyFormat = [{
Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;

#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
: llvm::Intrinsic::nvvm_redux_sync_f##op##abs

#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)

static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
NVVM::ReduxKind kind) {
if (!resultType->isIntegerTy(32))
NVVM::ReduxKind kind,
bool hasAbs, bool hasNaN) {
if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
llvm_unreachable("unsupported data type for redux");

switch (kind) {
Expand All @@ -47,6 +55,10 @@ static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
return llvm::Intrinsic::nvvm_redux_sync_max;
case NVVM::ReduxKind::MIN:
return llvm::Intrinsic::nvvm_redux_sync_min;
case NVVM::ReduxKind::FMIN:
return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
case NVVM::ReduxKind::FMAX:
return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
}
llvm_unreachable("unknown redux kind");
}
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,25 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
llvm.return %r1 : i32
}

llvm.func @redux_sync_f32(%value: f32, %offset: i32) -> f32 {
// CHECK: nvvm.redux.sync fmin %{{.*}}
%r1 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
// CHECK: nvvm.redux.sync fmin %{{.*}}
%r2 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
// CHECK: nvvm.redux.sync fmin %{{.*}}
%r3 = nvvm.redux.sync fmin %value, %offset {NaN = true}: f32 -> f32
// CHECK: nvvm.redux.sync fmin %{{.*}}
%r4 = nvvm.redux.sync fmin %value, %offset {abs = true, NaN = true}: f32 -> f32
// CHECK: nvvm.redux.sync fmax %{{.*}}
%r5 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
// CHECK: nvvm.redux.sync fmax %{{.*}}
%r6 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
// CHECK: nvvm.redux.sync fmax %{{.*}}
%r7 = nvvm.redux.sync fmax %value, %offset {NaN = true}: f32 -> f32
// CHECK: nvvm.redux.sync fmax %{{.*}}
%r8 = nvvm.redux.sync fmax %value, %offset {abs = true, NaN = true}: f32 -> f32
llvm.return %r1 : f32
}

// -----

Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,46 @@ llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
llvm.return
}

// -----
// CHECK-LABEL: @nvvm_redux_sync
llvm.func @nvvm_redux_sync(%value: i32, %offset: i32) {
// CHECK: call i32 @llvm.nvvm.redux.sync.add(i32 %{{.*}}, i32 %{{.*}})
%0 = nvvm.redux.sync add %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.umax(i32 %{{.*}}, i32 %{{.*}})
%1 = nvvm.redux.sync umax %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.umin(i32 %{{.*}}, i32 %{{.*}})
%2 = nvvm.redux.sync umin %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.and(i32 %{{.*}}, i32 %{{.*}})
%3 = nvvm.redux.sync and %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.or(i32 %{{.*}}, i32 %{{.*}})
%4 = nvvm.redux.sync or %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.xor(i32 %{{.*}}, i32 %{{.*}})
%5 = nvvm.redux.sync xor %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.max(i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.redux.sync max %value, %offset: i32 -> i32
// CHECK: call i32 @llvm.nvvm.redux.sync.min(i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.redux.sync min %value, %offset: i32 -> i32
llvm.return
}

// CHECK-LABEL: @nvvm_redux_sync_f32
llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
// CHECK: call float @llvm.nvvm.redux.sync.fmin(float %{{.*}}, i32 %{{.*}})
%0 = nvvm.redux.sync fmin %value, %offset: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmin.abs(float %{{.*}}, i32 %{{.*}})
%1 = nvvm.redux.sync fmin %value, %offset {abs = true}: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmin.NaN(float %{{.*}}, i32 %{{.*}})
%2 = nvvm.redux.sync fmin %value, %offset {nan = true}: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %{{.*}}, i32 %{{.*}})
%3 = nvvm.redux.sync fmin %value, %offset {abs = true, nan = true}: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmax(float %{{.*}}, i32 %{{.*}})
%4 = nvvm.redux.sync fmax %value, %offset: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmax.abs(float %{{.*}}, i32 %{{.*}})
%5 = nvvm.redux.sync fmax %value, %offset {abs = true}: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmax.NaN(float %{{.*}}, i32 %{{.*}})
%6 = nvvm.redux.sync fmax %value, %offset {nan = true}: f32 -> f32
// CHECK: call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %{{.*}}, i32 %{{.*}})
%7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
llvm.return
}