Skip to content

Commit d1dd736

Browse files
committed
Add force genisa
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent ad8ab0d commit d1dd736

File tree

4 files changed

+79
-27
lines changed

4 files changed

+79
-27
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5050
"TRITON_INTEL_ENABLE_INSTR_SCHED",
5151
"TRITON_INTEL_FAST_MATH",
5252
"TRITON_INTEL_REDUCE_TRANSPOSE",
53+
"TRITON_INTEL_ENABLE_SIMD_REDUCE",
54+
"TRITON_INTEL_ENHANCED_ACCELERATION_MATMUL",
55+
"TRITON_INTEL_ENABLE_DPAS_WARP_SIZE_32",
56+
"TRITONGEN_FORCE_GENISA",
5357
// clang-format on
5458
};
5559

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def TritonGEN_PrecisionTypeAttr : I32EnumAttr<"PrecisionType",
5555
I32EnumAttrCase<"S4", 5, "i4">,
5656
I32EnumAttrCase<"S2", 6, "i2">,
5757
I32EnumAttrCase<"BF8", 7, "bf8">,
58-
I32EnumAttrCase<"TF32", 8, "tf32">,
59-
I32EnumAttrCase<"BF16", 9, "bf16">,
60-
I32EnumAttrCase<"FP16", 10, "f16">
58+
I32EnumAttrCase<"TF32", 10, "tf32">,
59+
I32EnumAttrCase<"BF16", 11, "bf16">,
60+
I32EnumAttrCase<"FP16", 12, "f16">
6161
]> {
6262
let cppNamespace = "::mlir::triton::TritonGEN";
6363
}

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,15 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {
8585
return this->emitOpError(
8686
"1st operand (C) and result (D) should have the same type");
8787

88-
if (CTy.getNumElements() != getRc() || DTy.getNumElements() != getRc())
88+
auto useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA");
89+
90+
if (!useGenISA &&
91+
(CTy.getNumElements() != getRc() || DTy.getNumElements() != getRc()))
8992
return this->emitOpError("the dimension for 1st operand (C) and "
9093
"result (D) should match repeat count");
9194

9295
constexpr unsigned SD = 8;
93-
if (BTy.getNumElements() != SD)
96+
if (!useGenISA && BTy.getNumElements() != SD)
9497
return this->emitOpError("the dimension for the 3rd operand (B) should "
9598
"match the systolic depth of 8");
9699

@@ -141,7 +144,7 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {
141144
case TritonGEN::PrecisionType::FP16:
142145
case TritonGEN::PrecisionType::U8:
143146
case TritonGEN::PrecisionType::S8:
144-
if (ATy.getNumElements() != getRc())
147+
if (!useGenISA && ATy.getNumElements() != getRc())
145148
return this->emitOpError("2nd operand (A) should have the same number of "
146149
"elements as repeat count");
147150
if (!AElemTy.isInteger(16))
@@ -303,6 +306,9 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() {
303306
if (verify2DBlockLoadHWRestriction(*this).failed())
304307
return failure();
305308

309+
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA"))
310+
return success();
311+
306312
if (verifyMatrixInput(*this).failed())
307313
return failure();
308314

@@ -367,6 +373,9 @@ LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() {
367373
if (verify2DBlockStoreHWRestriction(*this).failed())
368374
return failure();
369375

376+
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA"))
377+
return success();
378+
370379
if (verifyMatrixInput(*this).failed())
371380
return failure();
372381

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
4141
#include "intel/include/TritonGENToSPIRV/TritonGENToSPIRVPass.h"
4242

43+
#include <triton/Tools/Sys/GetEnv.hpp>
44+
45+
#include "GenIntrinsicHelper.h"
46+
4347
namespace mlir::triton {
4448
#define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM
4549
#include "intel/include/TritonGENToLLVM/Passes.h.inc"
@@ -431,27 +435,48 @@ struct TritonMatrixDPASLowering
431435
if (cOrigTy != cTy)
432436
c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
433437

434-
std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL";
435-
SmallVector<Type> argTypes{int32Ty, aTy, bTy, cTy, int32Ty};
436-
fnName = intel::mangle(fnName, argTypes);
437-
438-
TritonLLVMOpBuilder builder(loc, rewriter);
439-
Value kDim = builder.i32_val(8 /*systolic depth*/ *
440-
getNumOperandsPerDword(precisionA));
441-
SmallVector<Value> args{
442-
kDim, a, b, c,
443-
builder.i32_val(getMatrixMultiplyAccumulateOperandsVal(
444-
cOrigTy.getElementType(), precisionA))};
445-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
446-
/*other=*/LLVM::ModRefInfo::NoModRef,
447-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
448-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
449-
auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs;
450-
funcAttrs.memEffectsAttr = memAttr;
438+
Value result;
439+
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) {
440+
MLIRContext *ctx = rewriter.getContext();
441+
auto builder = TritonLLVMOpBuilder(loc, rewriter);
442+
mlir::triton::gpu::intel::GenISA_Dpas dpasOp(rewriter, cTy, cTy, aTy,
443+
bTy);
444+
445+
// refer the call signature in GenISA
446+
result =
447+
dpasOp(rewriter, loc, c, a, b,
448+
builder.i32_val(
449+
static_cast<unsigned>(precisionA)), /*src0's precision*/
450+
builder.i32_val(
451+
static_cast<unsigned>(op.getPb())), /*src1's precision*/
452+
builder.i32_val(8), /*systolic depth*/
453+
builder.i32_val(8), /*repeate count*/
454+
builder.int_val(1, 0) /*is double = false*/)
455+
->getResult(0);
456+
} else {
457+
std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL";
458+
SmallVector<Type> argTypes{int32Ty, aTy, bTy, cTy, int32Ty};
459+
fnName = intel::mangle(fnName, argTypes);
460+
461+
TritonLLVMOpBuilder builder(loc, rewriter);
462+
Value kDim = builder.i32_val(8 /*systolic depth*/ *
463+
getNumOperandsPerDword(precisionA));
464+
SmallVector<Value> args{
465+
kDim, a, b, c,
466+
builder.i32_val(getMatrixMultiplyAccumulateOperandsVal(
467+
cOrigTy.getElementType(), precisionA))};
468+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
469+
/*other=*/LLVM::ModRefInfo::NoModRef,
470+
/*argMem=*/LLVM::ModRefInfo::NoModRef,
471+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
472+
auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs;
473+
funcAttrs.memEffectsAttr = memAttr;
474+
475+
result = intel::createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
476+
args, {}, funcAttrs)
477+
->getResult(0);
478+
}
451479

452-
Value result = intel::createDeviceFunctionCall(
453-
rewriter, fnName, cTy, argTypes, args, {}, funcAttrs)
454-
->getResult(0);
455480
if (cOrigTy != cTy)
456481
result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
457482

@@ -508,7 +533,8 @@ struct TritonMatrix2DBlockLoadLowering
508533
LogicalResult
509534
matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
510535
ConversionPatternRewriter &rewriter) const override {
511-
if (!isSPVBuiltinAvailable(op)) {
536+
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA") ||
537+
!isSPVBuiltinAvailable(op)) {
512538
// Fallback to GenISA interface.
513539
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
514540
return success();
@@ -583,6 +609,12 @@ struct TritonMatrix2DBlockStoreLowering
583609
LogicalResult
584610
matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
585611
ConversionPatternRewriter &rewriter) const override {
612+
// TODO: Remove GenISA lowering after PoC productization is completed.
613+
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) {
614+
rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter));
615+
return success();
616+
}
617+
586618
MLIRContext *ctx = rewriter.getContext();
587619
Location loc = op->getLoc();
588620
auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -651,6 +683,13 @@ struct TritonMatrix2DBlockPrefetchLowering
651683
LogicalResult
652684
matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
653685
ConversionPatternRewriter &rewriter) const override {
686+
// TODO: Remove GenISA lowering after PoC productization is completed.
687+
bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA");
688+
if (useGenISA) {
689+
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));
690+
return success();
691+
}
692+
654693
MLIRContext *ctx = rewriter.getContext();
655694
Location loc = op->getLoc();
656695
auto b = TritonLLVMOpBuilder(loc, rewriter);

0 commit comments

Comments
 (0)