From 6e468171b80a43cd542ce042b937207b58014c2d Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi
Date: Thu, 15 May 2025 16:41:08 +0530
Subject: [PATCH] [MLIR][NVVM] Update dot.accumulate NVVM Ops

This change:
- Adds the dot.accumulate.2way Op to the NVVM dialect for 16-bit to 8-bit
  dot-product accumulate operation.
- Refactors the recently added dot.accumulate.4way and adds a verifier.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 54 +++++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 22 +++++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir          |  9 ++++
 mlir/test/Target/LLVMIR/nvvmir.mlir         | 38 +++++++++++++++
 4 files changed, 123 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 2424e3af80d2d..596a584d485ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3703,6 +3703,60 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   }];
 }
 
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+  let description = [{
+    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
+    32-bit result.
+    Operand `a` is a vector of two 16-bit elements and operand `b` a vector
+    of four 8-bit elements between which the dot product is computed.
+
+    The `a_type` and `b_type` attributes specify the type of the elements in `a`
+    and `b` respectively.
+    If `a_type` or `b_type` is `s`, then the elements in the corresponding
+    vector are sign-extended to 32-bit before the dot product is computed.
+    If `a_type` or `b_type` is `u`, then the elements in the corresponding
+    vector are zero-extended to 32-bit instead.
+
+    The `b_hi` boolean attribute specifies which two bytes of `b` are used for
+    the dot product. If `b_hi` is true, then the dot product is computed
+    between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
+    then the dot product is computed between `a` and elements at indices 0 and
+    1 of `b`.
+
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if any of `a_type` or `b_type` is
+    signed.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [I16]>:$a,
+    DotAccumulateTypeAttr:$a_type,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    DotAccumulateTypeAttr:$b_type,
+    I32:$c,
+    BoolAttr:$b_hi
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 648b6b087e592..a77ff1e32dc23 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1712,6 +1712,28 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
   return {ids[type], args};
 }
 
+NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+  args.push_back(builder.getInt1(curOp.getBHi()));
+  args.push_back(mt.lookupValue(curOp.getC()));
+
+  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
+  unsigned type = (isASigned << 1) | isBSigned; // bit 1: a, bit 0: b
+  const llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_idp2a_u_u,
+      llvm::Intrinsic::nvvm_idp2a_u_s,
+      llvm::Intrinsic::nvvm_idp2a_s_u,
+      llvm::Intrinsic::nvvm_idp2a_s_s,
+  };
+  return {ids[type], args};
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 77b302155cb12..a02d33f50e0d2 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i
   return
 }
 
+// CHECK-LABEL: @dot_accumulate_2way
+func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = false} : vector<2xi16>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.2way %a_vec <u>, %b_vec <u>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
+  // CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {b_hi = true} : vector<2xi16>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.2way %a_vec <s>, %b_vec <s>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index e892fc43f4a39..660d0a22dce9c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   %3 = nvvm.dot.accumulate.4way %a , %b , %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {b_hi = false} : vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %2 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %3 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %4 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %5 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %6 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %7 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  llvm.return
+}