@@ -3444,6 +3444,70 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34443444 let hasVerifier = 1;
34453445}
34463446
3447+ //===----------------------------------------------------------------------===//
3448+ // NVVM dot.accumulate.4way Op
3449+ //===----------------------------------------------------------------------===//
3450+
// Enum cases naming the element type of a byte operand of the
// dot.accumulate.4way op.
// Case `S8`: integer value 1, spelled `s8` in the textual form.
3451+ def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
// Case `U8`: integer value 0, spelled `u8` in the textual form.
3452+ def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3453+
// 32-bit integer enum `DotAccumulate4WayType` made of the S8 and U8 cases,
// with its C++ declaration placed in the `::mlir::NVVM` namespace.
3454+ def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3455+ "NVVM DotAccumulate4WayType",
3456+ [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3457+ let cppNamespace = "::mlir::NVVM";
// Turn off generation of the specialized attribute class for this enum; the
// attribute is declared separately as DotAccumulate4WayTypeAttr.
3458+ let genSpecializedAttr = 0;
3459+ }
3460+
// NVVM dialect attribute wrapping the DotAccumulate4WayType enum, registered
// under the mnemonic `dot_accumulate_4way_type`.
3461+ def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
// Custom syntax: the enum value is written between angle brackets.
3462+ let assemblyFormat = "`<` $value `>`";
3463+ }
3464+
// Definition of the `dot.accumulate.4way` op: a four-way byte dot product of
// `a` and `b` accumulated into the 32-bit integer `c`, with the element
// signedness of each vector chosen by `a_type` / `b_type`.
3465+ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466+ let summary = "Four-way byte dot product-accumulate instruction.";
3467+ let description = [{
3468+ Performs a four-way byte dot-product which is accumulated in a 32-bit
3469+ result.
3470+ Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3471+ computed.
3472+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3473+ and `b` respectively.
3474+ If `a_type` or `b_type` is `s8`, then the elements in the corresponding
3475+ vector are sign-extended to 32-bit before the dot product is computed.
3476+ If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3477+ vector are zero-extended to 32-bit instead.
3478+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3479+ treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
3480+
3481+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
3482+ }];
3483+
// Operands and attributes, in declaration order: each 4 x i8 vector operand
// is followed by the attribute giving its element type; `c` is the i32
// accumulator input.
3484+ let arguments = (ins
3485+ VectorOfLengthAndType<[4], [I8]>:$a,
3486+ DotAccumulate4WayTypeAttr:$a_type,
3487+ VectorOfLengthAndType<[4], [I8]>:$b,
3488+ DotAccumulate4WayTypeAttr:$b_type,
3489+ I32:$c
3490+ );
3491+
// Single i32 result.
3492+ let results = (outs I32:$res);
3493+
// Textual form: `a a_type, b b_type, c attr-dict : type(a), type(b)`.
// Only the two vector types are spelled out; `c` and the result are both
// constrained to I32 above.
3494+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3495+
// Extra C++ members declared on the op class. Only the declarations appear
// here; the definitions live elsewhere.
//  - getIntrinsicID: static; returns an LLVM intrinsic ID for a given
//    (a_type, b_type) pair.
//  - getPackedArg: takes an llvm::Value and an IR builder and returns an
//    llvm::Value. NOTE(review): presumably repacks the 4 x i8 vector into
//    the form the intrinsic expects -- confirm against the definition.
3496+ let extraClassDeclaration = [{
3497+ static llvm::Intrinsic::ID
3498+ getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499+ NVVM::DotAccumulate4WayType b_type);
3500+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3501+ }];
3502+
// Translation to LLVM IR: choose the intrinsic from the two type attributes,
// run `a` and `b` through getPackedArg, then emit one intrinsic call with
// arguments (argA, argB, c) whose value becomes the op result.
3503+ string llvmBuilder = [{
3504+ llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3505+ llvm::Value* argA = op.getPackedArg($a, builder);
3506+ llvm::Value* argB = op.getPackedArg($b, builder);
3507+ $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3508+ }];
3509+ }
3510+
34473511//===----------------------------------------------------------------------===//
34483512// NVVM target attribute.
34493513//===----------------------------------------------------------------------===//
0 commit comments