Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3703,6 +3703,60 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
}];
}

// Two-way dot-product-accumulate op (`nvvm.dot.accumulate.2way`).
//
// Takes a 2 x i16 vector `a`, a 4 x i8 vector `b` and an i32 accumulator `c`,
// plus per-operand signedness attributes and a `b_hi` flag choosing which
// half of `b` participates. The intrinsic and its arguments are produced by
// the C++ hook `getIntrinsicIDAndArgs` declared below.
def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
  let description = [{
    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
    32-bit result.
    Operand `a` is a vector of two 16-bit elements and operand `b` a vector
    of four 8-bit elements between which the dot product is computed.

    The `a_type` and `b_type` attributes specify the type of the elements in `a`
    and `b` respectively.
    If `a_type` or `b_type` is `signed`, then the elements in the corresponding
    vector are sign-extended to 32-bit before the dot product is computed.
    If `a_type` or `b_type` is `unsigned`, then the elements in the
    corresponding vector are zero-extended to 32-bit instead.

    The `b_hi` boolean attribute specifies which two bytes of `b` are used for
    the dot product. If `b_hi` is true, then the dot product is computed
    between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
    then the dot product is computed between `a` and elements at indices 0 and
    1 of `b`.

    Operand `c` is a 32-bit integer to which the result is accumulated. It is
    treated as holding a signed integer if any of `a_type` or `b_type` is
    `signed`.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
  }];

  // Each vector operand is immediately followed by its signedness attribute,
  // mirroring the order used in the assembly format below.
  let arguments = (ins
    VectorOfLengthAndType<[2], [I16]>:$a,
    DotAccumulateTypeAttr:$a_type,
    VectorOfLengthAndType<[4], [I8]>:$b,
    DotAccumulateTypeAttr:$b_type,
    I32:$c,
    BoolAttr:$b_hi
  );

  let results = (outs I32:$res);

  // `b_hi` has no dedicated slot in the format, so it is carried in the
  // attribute dictionary (e.g. `{b_hi = true}`).
  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";

  let extraClassDeclaration = [{
    static mlir::NVVM::IDArgPair
    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                          llvm::IRBuilderBase &builder);
  }];

  // Translation to LLVM IR: ask the C++ hook for the intrinsic ID and its
  // arguments, then emit a single intrinsic call producing `$res`.
  string llvmBuilder = [{
    auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
        *op, moduleTranslation, builder);
    $res = createIntrinsicCall(builder, id, args);
  }];
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1712,6 +1712,28 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
return {ids[type], args};
}

/// Picks the `idp2a` intrinsic for an `nvvm.dot.accumulate.2way` op and
/// gathers its call arguments.
///
/// The returned arguments are, in order: operand `a` passed through
/// getAsPackedI32, operand `b` passed through getAsPackedI32, the `b_hi`
/// attribute as an i1 constant, and the accumulator `c`. The intrinsic variant
/// is chosen from the signedness of `a` (first suffix) and `b` (second
/// suffix).
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dp2aOp = cast<NVVM::DotAccumulate2WayOp>(op);

  // Lower the operands in source order (`a` first, then `b`) so that any IR
  // emitted while packing keeps the same relative order.
  llvm::Value *packedA =
      getAsPackedI32(mt.lookupValue(dp2aOp.getA()), builder);
  llvm::Value *packedB =
      getAsPackedI32(mt.lookupValue(dp2aOp.getB()), builder);
  llvm::Value *highHalfFlag = builder.getInt1(dp2aOp.getBHi());
  llvm::Value *accumulator = mt.lookupValue(dp2aOp.getC());

  llvm::SmallVector<llvm::Value *> args{packedA, packedB, highHalfFlag,
                                        accumulator};

  const bool aSigned = dp2aOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  const bool bSigned = dp2aOp.getBType() == NVVM::DotAccumulateType::SIGNED;

  // Intrinsic suffix is `<a>_<b>` with `s` for signed and `u` for unsigned.
  llvm::Intrinsic::ID id;
  if (aSigned)
    id = bSigned ? llvm::Intrinsic::nvvm_idp2a_s_s
                 : llvm::Intrinsic::nvvm_idp2a_s_u;
  else
    id = bSigned ? llvm::Intrinsic::nvvm_idp2a_u_s
                 : llvm::Intrinsic::nvvm_idp2a_u_u;

  return {id, args};
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i
return
}

// Dialect-level test for `nvvm.dot.accumulate.2way`: the op is expected to
// appear in the output with its `b_hi` attribute and both operand vector
// types. One case uses unsigned operands with `b_hi = false`, the other
// signed operands with `b_hi = true`.
// NOTE(review): the `%{{.*}}` operand patterns are loose enough to also absorb
// the `<signed>`/`<unsigned>` markers, so signedness is not pinned down here.
// CHECK-LABEL: @dot_accumulate_2way
func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
// Unsigned x unsigned, low half of `b` (b_hi = false).
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = false} : vector<2xi16>, vector<4xi8>
%1 = nvvm.dot.accumulate.2way %a_vec <unsigned>, %b_vec <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
// Signed x signed, high half of `b` (b_hi = true).
// CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = true} : vector<2xi16>, vector<4xi8>
%3 = nvvm.dot.accumulate.2way %a_vec <signed>, %b_vec <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
return
}

// -----

// Just check these don't emit errors.
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
%3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
llvm.return
}

// -----
// Translation test for `nvvm.dot.accumulate.2way`. Covers all eight
// combinations of `a` signedness, `b` signedness and `b_hi`. Each case expects
// both vector operands to be bitcast to i32, followed by a call to
// `llvm.nvvm.idp2a.<a>.<b>` whose third argument is the `b_hi` flag as an i1
// and whose fourth argument is the accumulator.
// CHECK-LABEL: @nvvm_dot_accumulate_2way
llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
// Case: a unsigned, b unsigned, b_hi = false -> idp2a.u.u, i1 false.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
%0 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = false} : vector<2xi16>, vector<4xi8>
// Case: a unsigned, b unsigned, b_hi = true -> idp2a.u.u, i1 true.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
%1 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
// Case: a signed, b unsigned, b_hi = false -> idp2a.s.u, i1 false.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
%2 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
// Case: a signed, b unsigned, b_hi = true -> idp2a.s.u, i1 true.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
%3 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
// Case: a unsigned, b signed, b_hi = false -> idp2a.u.s, i1 false.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
%4 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
// Case: a unsigned, b signed, b_hi = true -> idp2a.u.s, i1 true.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
%5 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
// Case: a signed, b signed, b_hi = false -> idp2a.s.s, i1 false.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
%6 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
// Case: a signed, b signed, b_hi = true -> idp2a.s.s, i1 true.
// CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
%7 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
llvm.return
}
Loading