|
40 | 40 | #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
|
41 | 41 | #include "intel/include/TritonGENToSPIRV/TritonGENToSPIRVPass.h"
|
42 | 42 |
|
| 43 | +#include <triton/Tools/Sys/GetEnv.hpp> |
| 44 | + |
| 45 | +#include "GenIntrinsicHelper.h" |
| 46 | + |
43 | 47 | namespace mlir::triton {
|
44 | 48 | #define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM
|
45 | 49 | #include "intel/include/TritonGENToLLVM/Passes.h.inc"
|
@@ -431,27 +435,48 @@ struct TritonMatrixDPASLowering
|
431 | 435 | if (cOrigTy != cTy)
|
432 | 436 | c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
|
433 | 437 |
|
434 |
| - std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL"; |
435 |
| - SmallVector<Type> argTypes{int32Ty, aTy, bTy, cTy, int32Ty}; |
436 |
| - fnName = intel::mangle(fnName, argTypes); |
437 |
| - |
438 |
| - TritonLLVMOpBuilder builder(loc, rewriter); |
439 |
| - Value kDim = builder.i32_val(8 /*systolic depth*/ * |
440 |
| - getNumOperandsPerDword(precisionA)); |
441 |
| - SmallVector<Value> args{ |
442 |
| - kDim, a, b, c, |
443 |
| - builder.i32_val(getMatrixMultiplyAccumulateOperandsVal( |
444 |
| - cOrigTy.getElementType(), precisionA))}; |
445 |
| - auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( |
446 |
| - /*other=*/LLVM::ModRefInfo::NoModRef, |
447 |
| - /*argMem=*/LLVM::ModRefInfo::NoModRef, |
448 |
| - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); |
449 |
| - auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs; |
450 |
| - funcAttrs.memEffectsAttr = memAttr; |
| 438 | + Value result; |
| 439 | + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) { |
| 440 | + MLIRContext *ctx = rewriter.getContext(); |
| 441 | + auto builder = TritonLLVMOpBuilder(loc, rewriter); |
| 442 | + mlir::triton::gpu::intel::GenISA_Dpas dpasOp(rewriter, cTy, cTy, aTy, |
| 443 | + bTy); |
| 444 | + |
| 445 | + // refer the call signature in GenISA |
| 446 | + result = |
| 447 | + dpasOp(rewriter, loc, c, a, b, |
| 448 | + builder.i32_val( |
| 449 | + static_cast<unsigned>(precisionA)), /*src0's precision*/ |
| 450 | + builder.i32_val( |
| 451 | + static_cast<unsigned>(op.getPb())), /*src1's precision*/ |
| 452 | + builder.i32_val(8), /*systolic depth*/ |
| 453 | + builder.i32_val(8), /*repeate count*/ |
| 454 | + builder.int_val(1, 0) /*is double = false*/) |
| 455 | + ->getResult(0); |
| 456 | + } else { |
| 457 | + std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL"; |
| 458 | + SmallVector<Type> argTypes{int32Ty, aTy, bTy, cTy, int32Ty}; |
| 459 | + fnName = intel::mangle(fnName, argTypes); |
| 460 | + |
| 461 | + TritonLLVMOpBuilder builder(loc, rewriter); |
| 462 | + Value kDim = builder.i32_val(8 /*systolic depth*/ * |
| 463 | + getNumOperandsPerDword(precisionA)); |
| 464 | + SmallVector<Value> args{ |
| 465 | + kDim, a, b, c, |
| 466 | + builder.i32_val(getMatrixMultiplyAccumulateOperandsVal( |
| 467 | + cOrigTy.getElementType(), precisionA))}; |
| 468 | + auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( |
| 469 | + /*other=*/LLVM::ModRefInfo::NoModRef, |
| 470 | + /*argMem=*/LLVM::ModRefInfo::NoModRef, |
| 471 | + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); |
| 472 | + auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs; |
| 473 | + funcAttrs.memEffectsAttr = memAttr; |
| 474 | + |
| 475 | + result = intel::createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, |
| 476 | + args, {}, funcAttrs) |
| 477 | + ->getResult(0); |
| 478 | + } |
451 | 479 |
|
452 |
| - Value result = intel::createDeviceFunctionCall( |
453 |
| - rewriter, fnName, cTy, argTypes, args, {}, funcAttrs) |
454 |
| - ->getResult(0); |
455 | 480 | if (cOrigTy != cTy)
|
456 | 481 | result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
|
457 | 482 |
|
@@ -508,7 +533,8 @@ struct TritonMatrix2DBlockLoadLowering
|
508 | 533 | LogicalResult
|
509 | 534 | matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
|
510 | 535 | ConversionPatternRewriter &rewriter) const override {
|
511 |
| - if (!isSPVBuiltinAvailable(op)) { |
| 536 | + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA") || |
| 537 | + !isSPVBuiltinAvailable(op)) { |
512 | 538 | // Fallback to GenISA interface.
|
513 | 539 | rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
|
514 | 540 | return success();
|
@@ -583,6 +609,12 @@ struct TritonMatrix2DBlockStoreLowering
|
583 | 609 | LogicalResult
|
584 | 610 | matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
|
585 | 611 | ConversionPatternRewriter &rewriter) const override {
|
| 612 | + // TODO: Remove GenISA lowering after PoC productization is completed. |
| 613 | + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) { |
| 614 | + rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter)); |
| 615 | + return success(); |
| 616 | + } |
| 617 | + |
586 | 618 | MLIRContext *ctx = rewriter.getContext();
|
587 | 619 | Location loc = op->getLoc();
|
588 | 620 | auto b = TritonLLVMOpBuilder(loc, rewriter);
|
@@ -651,6 +683,13 @@ struct TritonMatrix2DBlockPrefetchLowering
|
651 | 683 | LogicalResult
|
652 | 684 | matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
|
653 | 685 | ConversionPatternRewriter &rewriter) const override {
|
| 686 | + // TODO: Remove GenISA lowering after PoC productization is completed. |
| 687 | + bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA"); |
| 688 | + if (useGenISA) { |
| 689 | + rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter)); |
| 690 | + return success(); |
| 691 | + } |
| 692 | + |
654 | 693 | MLIRContext *ctx = rewriter.getContext();
|
655 | 694 | Location loc = op->getLoc();
|
656 | 695 | auto b = TritonLLVMOpBuilder(loc, rewriter);
|
|
0 commit comments