@@ -3445,25 +3445,28 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34453445}
34463446
34473447//===----------------------------------------------------------------------===//
3448- // NVVM dot.accumulate.4way Op
3448+ // NVVM dot.accumulate Ops
34493449//===----------------------------------------------------------------------===//
34503450
3451- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3451+ def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
3452+ def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
3453+ def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
3454+ def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
34533455
3454- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3455- "NVVM DotAccumulate4WayType",
3456- [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3456+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
3457+ "NVVM DotAccumulateType",
3458+ [DotAccumulateS8, DotAccumulateU8,
3459+ DotAccumulateS16, DotAccumulateU16]> {
34573460 let cppNamespace = "::mlir::NVVM";
34583461 let genSpecializedAttr = 0;
34593462}
34603463
3461- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3464+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
34623465 let assemblyFormat = "`<` $value `>`";
34633466}
34643467
34653468def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3466- let summary = "Four-way byte dot product-accumulate instruction. ";
3469+ let summary = "Four-way byte dot product-accumulate instruction";
34673470 let description = [{
34683471 Performs a four-way byte dot-product which is accumulated in a 32-bit
34693472 result.
@@ -3481,11 +3484,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34813484 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
34823485 }];
34833486
3487+ let hasVerifier = 1;
3488+
34843489 let arguments = (ins
34853490 VectorOfLengthAndType<[4], [I8]>:$a,
3486- DotAccumulate4WayTypeAttr :$a_type,
3491+ DotAccumulateTypeAttr :$a_type,
34873492 VectorOfLengthAndType<[4], [I8]>:$b,
3488- DotAccumulate4WayTypeAttr :$b_type,
3493+ DotAccumulateTypeAttr :$b_type,
34893494 I32:$c
34903495 );
34913496
@@ -3495,8 +3500,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
34953500
34963501 let extraClassDeclaration = [{
34973502 static llvm::Intrinsic::ID
3498- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3499- NVVM::DotAccumulate4WayType b_type);
3503+ getIntrinsicID(NVVM::DotAccumulateType a_type,
3504+ NVVM::DotAccumulateType b_type);
35003505 llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
35013506 }];
35023507
@@ -3508,6 +3513,86 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35083513 }];
35093514}
35103515
3516+ def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
3517+ def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
3518+
3519+ def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
3520+ "NVVM DotAccumulate2WayMode",
3521+ [DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
3522+ let cppNamespace = "::mlir::NVVM";
3523+ let genSpecializedAttr = 0;
3524+ }
3525+
3526+ def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
3527+ let assemblyFormat = "$value";
3528+ }
3529+
3530+ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
3531+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
3532+ let description = [{
3533+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
3534+ 32-bit result.
3535+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
3536+ of four 8-bit elements between which the dot product is computed.
3537+
3538+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
3539+ and `b` respectively.
3540+ If `a_type` is `s16`, then the elements in `a` are sign-extended to
3541+ 32-bit before the dot product is computed.
3542+ If `a_type` is `u16`, then the elements in `a` are zero-extended to
3543+ 32-bit instead.
3544+ If `b_type` is `s8`, then the elements in `b` are sign-extended to
3545+ 32-bit before the dot product is computed.
3546+ If `b_type` is `u8`, then the elements in `b` are zero-extended to
3547+ 32-bit instead.
3548+
3549+ The `hi` boolean attribute specifies which two bytes of `b` are used for
3550+ the dot product. If `hi` is true, then the dot product is computed between
3551+ `a` and elements at indices 2 and 3 of `b`. If `hi` is false, then the dot
3552+ product is computed between `a` and elements at indices 0 and 1 of `b`.
3553+ By default, `hi` is false.
3554+
3555+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
3556+ treated as holding a signed integer if any of `a_type` or `b_type` is
3557+ signed.
3558+
3559+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
3560+ }];
3561+
3562+ let hasVerifier = 1;
3563+
3564+ let arguments = (ins
3565+ VectorOfLengthAndType<[2], [I16]>:$a,
3566+ DotAccumulateTypeAttr:$a_type,
3567+ VectorOfLengthAndType<[4], [I8]>:$b,
3568+ DotAccumulateTypeAttr:$b_type,
3569+ I32:$c,
3570+ DefaultValuedAttr<BoolAttr, "false">:$hi
3571+ );
3572+
3573+ let results = (outs I32:$res);
3574+
3575+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3576+
3577+ let extraClassDeclaration = [{
3578+ static llvm::Intrinsic::ID
3579+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3580+ llvm::IRBuilderBase &builder,
3581+ llvm::SmallVector<llvm::Value *> &args);
3582+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3583+ }];
3584+
3585+ string llvmBuilder = [{
3586+ llvm::SmallVector<llvm::Value *> args;
3587+
3588+ llvm::Intrinsic::ID
3589+ id = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3590+ *op, moduleTranslation, builder, args);
3591+
3592+ $res = createIntrinsicCall(builder, id, args);
3593+ }];
3594+ }
3595+
35113596//===----------------------------------------------------------------------===//
35123597// NVVM target attribute.
35133598//===----------------------------------------------------------------------===//
0 commit comments