17 | 17 | #include "triton/Dialect/Triton/IR/Utility.h" |
18 | 18 | #include "llvm/Support/ErrorHandling.h" |
19 | 19 |
|
20 | | -#define DEBUG_TYPE "ttgpu_to_llvm" |
21 | | - |
22 | | -using namespace mlir; |
23 | | - |
24 | 20 | namespace mlir::LLVM::intel { |
25 | 21 |
|
26 | 22 | /// Create a predicated block, using \p cond as the condition and \p ops for the |
@@ -77,9 +73,6 @@ Block &createPredicatedBlock(RewriterBase &rewriter, Location loc, Value cond, |
77 | 73 | return createPredicatedBlock(rewriter, loc, cond, {}, thenOpsFn); |
78 | 74 | } |
79 | 75 |
|
80 | | -Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, |
81 | | - Type elemTy, Value pred); |
82 | | - |
83 | 76 | Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); |
84 | 77 | Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); |
85 | 78 | Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); |
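Context for the declarations kept above: shuffleXor, shuffleUp, and shuffleIdx wrap sub-group data exchange, and shuffleXor is the usual building block for butterfly reductions. A minimal standalone sketch of that access pattern follows (a plain C++ simulation, not the backend's implementation; the 16-lane width and the function name are assumptions for illustration only):

// Illustrative sketch only (not part of this header): simulates the butterfly
// reduction pattern that shuffleXor enables on a sub-group. The 16-lane array
// stands in for per-lane values; the function name is hypothetical.
#include <array>

static int butterflyReduceLane0(std::array<int, 16> lanes) {
  for (int offset = (int)lanes.size() / 2; offset > 0; offset /= 2) {
    std::array<int, 16> next = lanes;
    for (int lane = 0; lane < (int)lanes.size(); ++lane)
      next[lane] = lanes[lane] + lanes[lane ^ offset]; // shuffleXor(val, offset)
    lanes = next;
  }
  return lanes[0]; // after log2(16) rounds every lane holds the full sum
}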
@@ -129,16 +122,7 @@ static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) { |
129 | 122 |
|
130 | 123 | } // namespace mlir::LLVM::intel |
131 | 124 |
|
132 | | -// ----------------------------------------------------------------------- |
133 | | -// Shared memory utilities |
134 | | -// ----------------------------------------------------------------------- |
135 | | -using ::mlir::triton::getMultiDimIndex; |
136 | | -using ::mlir::triton::gpu::BlockedEncodingAttr; |
137 | | -using ::mlir::triton::gpu::CTALayoutAttr; |
138 | | -using ::mlir::triton::gpu::DotOperandEncodingAttr; |
139 | | -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; |
140 | | -using ::mlir::triton::gpu::SliceEncodingAttr; |
141 | | -using ::mlir::triton::gpu::intel::DpasEncodingAttr; |
| 125 | +using mlir::triton::gpu::intel::DpasEncodingAttr; |
142 | 126 |
|
143 | 127 | static SmallVector<Value> |
144 | 128 | emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter, |
@@ -457,48 +441,6 @@ namespace mlir::triton::intel { |
457 | 441 | inline SmallVector<SmallVector<unsigned>> |
458 | 442 | emitOffsetForLayout(Attribute layout, RankedTensorType type); |
459 | 443 |
|
460 | | -inline SmallVector<SmallVector<unsigned>> |
461 | | -emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, |
462 | | - RankedTensorType type) { |
463 | | - auto parentEncoding = sliceLayout.getParent(); |
464 | | - unsigned dim = sliceLayout.getDim(); |
465 | | - auto parentShape = sliceLayout.paddedShape(type.getShape()); |
466 | | - RankedTensorType parentTy = |
467 | | - RankedTensorType::get(parentShape, type.getElementType(), parentEncoding); |
468 | | - auto parentOffsets = ::intel::emitOffsetForLayout(parentEncoding, parentTy); |
469 | | - if (parentOffsets.empty()) |
470 | | - return {}; |
471 | | - |
472 | | - SmallVector<SmallVector<unsigned>> resultOffsets; |
473 | | - std::set<SmallVector<unsigned>> uniqueOffsets; |
474 | | - |
475 | | - for (unsigned i = 0; i < parentOffsets.size(); ++i) { |
476 | | - SmallVector<unsigned> offsets(parentOffsets[i].begin(), |
477 | | - parentOffsets[i].end()); |
478 | | - offsets.erase(offsets.begin() + dim); |
479 | | - if (auto [it, inserted] = uniqueOffsets.insert(offsets); inserted) { |
480 | | - resultOffsets.push_back(offsets); |
481 | | - } |
482 | | - } |
483 | | - |
484 | | - // It can happen that after deduplicating elements above, resultOffsets has |
485 | | - // fewer than getTotalElementsPerThread() elements. In that case repeat the |
486 | | - // sequence. |
487 | | - int elemsPerThread = triton::gpu::getTotalElemsPerThread(type); |
488 | | - assert(resultOffsets.size() > 0); |
489 | | - assert(elemsPerThread % resultOffsets.size() == 0); |
490 | | - int numRepeats = elemsPerThread / resultOffsets.size(); |
491 | | - SmallVector<SmallVector<unsigned>> ret; |
492 | | - for (int i = 0; i < numRepeats; ++i) { |
493 | | - for (unsigned j = 0; j < resultOffsets.size(); ++j) { |
494 | | - ret.push_back(SmallVector<unsigned>(resultOffsets[j])); |
495 | | - } |
496 | | - } |
497 | | - return ret; |
498 | | -} |
499 | | - |
500 | | -// |
501 | | - |
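The removed emitOffsetForSliceLayout projected the parent layout's per-thread offsets onto the slice by erasing the sliced dimension, deduplicating the result, and repeating the deduplicated sequence until it covers getTotalElemsPerThread elements. A self-contained sketch of that dedupe-and-repeat step on plain standard-library containers, rather than MLIR values (the function name and the sample data in the trailing comment are illustrative only):

// Illustrative sketch only: mirrors the dedupe-then-repeat step of the removed
// emitOffsetForSliceLayout, using std::vector instead of SmallVector.
#include <cassert>
#include <set>
#include <vector>

static std::vector<std::vector<unsigned>>
dedupeAndRepeat(const std::vector<std::vector<unsigned>> &projected,
                unsigned elemsPerThread) {
  std::vector<std::vector<unsigned>> unique;
  std::set<std::vector<unsigned>> seen;
  for (const auto &off : projected)
    if (seen.insert(off).second) // keep only the first occurrence
      unique.push_back(off);

  assert(!unique.empty() && elemsPerThread % unique.size() == 0);
  std::vector<std::vector<unsigned>> result;
  for (unsigned rep = 0; rep < elemsPerThread / unique.size(); ++rep)
    result.insert(result.end(), unique.begin(), unique.end());
  // e.g. projected = {{0},{4},{0},{4}}, elemsPerThread = 4
  //   -> unique = {{0},{4}}, result = {{0},{4},{0},{4}}
  return result;
}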
502 | 444 | // ----------------------------------------------------------------------- |
503 | 445 | // Get offsets / indices for any layout |
504 | 446 | // ----------------------------------------------------------------------- |
@@ -605,174 +547,11 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, |
605 | 547 | return multiDimIdx; |
606 | 548 | } |
607 | 549 |
|
608 | | -/* ---------------- */ |
609 | | -/* ---------------- */ |
610 | | -inline DenseMap<unsigned, Value> getSwizzledSharedPtrs( |
611 | | - Location loc, const TargetInfoBase &target, unsigned inVec, |
612 | | - RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, |
613 | | - Type resElemTy, const SharedMemoryObject &shrMemObj, RewriterBase &rewriter, |
614 | | - SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) { |
615 | | - // This utility computes the pointers for accessing the provided swizzled |
616 | | - // shared memory layout `resSharedLayout`. More specifically, it computes, |
617 | | - // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, |
618 | | - // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + |
619 | | - // colOff) where : |
620 | | - // phase = (row // perPhase) % maxPhase |
621 | | - // rowOff = row |
622 | | - // colOff = colOffSwizzled + colOffOrdered |
623 | | - // colOffSwizzled = ((col // outVec) ^ phase) * outVec |
624 | | - // colOffOrdered = (col % outVec) // minVec * minVec |
625 | | - // |
626 | | - // Note 1: |
627 | | - // ------- |
628 | | - // Because swizzling happens at a granularity of outVec, we need to |
629 | | - // decompose the offset into a swizzled factor and a non-swizzled |
630 | | - // (ordered) factor |
631 | | - // |
632 | | - // Note 2: |
633 | | - // ------- |
634 | | - // If we have x, y, z of the form: |
635 | | - // x = 0b00000xxxx |
636 | | - // y = 0byyyyy0000 |
637 | | - // z = 0b00000zzzz |
638 | | - // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y |
639 | | - // This means that we can use some immediate offsets for shared memory |
640 | | - // operations. |
641 | | - auto dstPtrTy = shrMemObj.base.getType(); |
642 | | - auto dstOffset = dot(rewriter, loc, offsetVals, shrMemObj.strides); |
643 | | - Value dstPtrBase = gep(dstPtrTy, resElemTy, shrMemObj.base, dstOffset); |
644 | | - |
645 | | - auto srcEncoding = srcTy.getEncoding(); |
646 | | - auto srcShape = srcTy.getShape(); |
647 | | - auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy); |
648 | | - unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); |
649 | | - // swizzling params as described in TritonGPUAttrDefs.td |
650 | | - unsigned outVec = resSharedLayout.getVec(); |
651 | | - unsigned perPhase = resSharedLayout.getPerPhase(); |
652 | | - unsigned maxPhase = resSharedLayout.getMaxPhase(); |
653 | | - // Order |
654 | | - auto inOrder = triton::gpu::getOrder(srcEncoding); |
655 | | - auto outOrder = triton::gpu::getOrder(resSharedLayout); |
656 | | - assert(maxPhase == 1 || |
657 | | - outVec * maxPhase <= srcShape[outOrder[0]] && |
658 | | - "Swizzling would generate out of bounds memory accesses"); |
659 | | - // Tensor indices held by the current thread, as LLVM values |
660 | | - auto srcIndices = ::intel::emitIndices(loc, rewriter, target, srcEncoding, |
661 | | - srcTy, /*withCTAOffset=*/false); |
662 | | - // Swizzling with leading offsets (e.g. Hopper GMMA) |
663 | | - unsigned swizzlingByteWidth = 0; |
664 | | - if (resSharedLayout.getHasLeadingOffset()) { |
665 | | - if (perPhase == 4 && maxPhase == 2) |
666 | | - swizzlingByteWidth = 32; |
667 | | - else if (perPhase == 2 && maxPhase == 4) |
668 | | - swizzlingByteWidth = 64; |
669 | | - else if (perPhase == 1 && maxPhase == 8) |
670 | | - swizzlingByteWidth = 128; |
671 | | - else |
672 | | - llvm::report_fatal_error("Unsupported shared layout."); |
673 | | - } |
674 | | - unsigned numElemsPerSwizzlingRow = |
675 | | - swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth(); |
676 | | - Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow); |
677 | | - unsigned leadingDimOffset; |
678 | | - if (outOrder.size() >= 2) { |
679 | | - leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]]; |
680 | | - } else { |
681 | | - leadingDimOffset = numElemsPerSwizzlingRow; |
682 | | - } |
683 | | - |
684 | | - Value leadingDimOffsetVal = i32_val(leadingDimOffset); |
685 | | - // Return values |
686 | | - DenseMap<unsigned, Value> ret; |
687 | | - // cache for non-immediate offsets |
688 | | - DenseMap<unsigned, Value> cacheCol, cacheRow; |
689 | | - unsigned minVec = std::min(outVec, inVec); |
690 | | - Value strideRow = outOrder.size() >= 2 ? srcStrides[outOrder[1]] : i32_val(0); |
691 | | - Value strideCol = srcStrides[outOrder[0]]; |
692 | | - LDBG("getSwizzledSharedPtrs: perPhase = " |
693 | | - << perPhase << " maxPhase = " << maxPhase << " minVec = " << minVec |
694 | | - << " inVec = " << inVec << " outVec = " << outVec << " strideRow " |
695 | | - << strideRow << " strideCol " << strideCol); |
696 | | - for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { |
697 | | - Value offset = i32_val(0); |
698 | | - // Extract multi dimensional index for current element |
699 | | - auto idx = srcIndices[elemIdx]; |
700 | | - Value idxCol = idx[outOrder[0]]; // contiguous dimension |
701 | | - Value idxRow; |
702 | | - if (outOrder.size() >= 2) { |
703 | | - idxRow = idx[outOrder[1]]; // discontiguous dimension |
704 | | - } else { |
705 | | - idxRow = i32_val(0); |
706 | | - } |
707 | | - // compute phase = (row // perPhase) % maxPhase |
708 | | - Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); |
709 | | - // extract dynamic/static offset for immediate offsetting |
710 | | - unsigned immedateOffCol = 0; |
711 | | - unsigned immedateOffRow = 0; |
712 | | - if (leadingDimOffset) { |
713 | | - // hopper |
714 | | - offset = |
715 | | - mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal); |
716 | | - // Shrink by swizzling blocks |
717 | | - idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); |
718 | | - strideRow = numElemsPerSwizzlingRowVal; |
719 | | - } |
720 | | - if (auto add = idxCol.getDefiningOp<LLVM::AddOp>()) { |
721 | | - if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) { |
722 | | - unsigned cst = |
723 | | - cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue(); |
724 | | - unsigned key = cst % (outVec * maxPhase); |
725 | | - cacheCol.insert({key, idxCol}); |
726 | | - idxCol = cacheCol[key]; |
727 | | - immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); |
728 | | - } |
729 | | - } |
730 | | - if (auto add = idxRow.getDefiningOp<LLVM::AddOp>()) { |
731 | | - if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) { |
732 | | - unsigned cst = |
733 | | - cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue(); |
734 | | - unsigned key = cst % (perPhase * maxPhase); |
735 | | - cacheRow.insert({key, idxRow}); |
736 | | - idxRow = cacheRow[key]; |
737 | | - immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase); |
738 | | - } |
739 | | - } |
740 | | - // row offset is simply row index |
741 | | - Value rowOff = mul(idxRow, strideRow); |
742 | | - // because swizzling happens at a granularity of outVec, we need to |
743 | | - // decompose the offset into a swizzled factor and a non-swizzled |
744 | | - // (ordered) factor: colOffSwizzled = ((col // outVec) ^ phase) * outVec |
745 | | - // colOffOrdered = (col % outVec) // minVec * minVec |
746 | | - Value colOffSwizzled = xor_(udiv(idxCol, i32_val(outVec)), phase); |
747 | | - colOffSwizzled = mul(colOffSwizzled, i32_val(outVec)); |
748 | | - Value colOffOrdered = urem(idxCol, i32_val(outVec)); |
749 | | - colOffOrdered = udiv(colOffOrdered, i32_val(minVec)); |
750 | | - colOffOrdered = mul(colOffOrdered, i32_val(minVec)); |
751 | | - Value colOff = add(colOffSwizzled, colOffOrdered); |
752 | | - // compute non-immediate offset |
753 | | - if (outOrder.size() == 3) |
754 | | - offset = add(offset, mul(idx[outOrder[2]], srcStrides[outOrder[2]])); |
755 | | - offset = add(offset, add(rowOff, mul(colOff, strideCol))); |
756 | | - Value currPtr = gep(dstPtrTy, resElemTy, dstPtrBase, offset); |
757 | | - // compute immediate offset |
758 | | - Value immediateOff; |
759 | | - if (outOrder.size() >= 2) { |
760 | | - immediateOff = |
761 | | - add(mul(i32_val(immedateOffRow), strideRow), i32_val(immedateOffCol)); |
762 | | - } else { |
763 | | - immediateOff = i32_val(immedateOffCol); |
764 | | - } |
765 | | - |
766 | | - ret[elemIdx] = gep(dstPtrTy, resElemTy, currPtr, immediateOff); |
767 | | - } |
768 | | - return ret; |
769 | | -} |
770 | | - |
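The removed getSwizzledSharedPtrs documented its offset arithmetic inline: phase = (row // perPhase) % maxPhase, colOffSwizzled = ((col // outVec) ^ phase) * outVec, colOffOrdered = (col % outVec) // minVec * minVec. A standalone sketch that evaluates the same arithmetic on scalar values (plain C++; the example parameters in the trailing comment are arbitrary, not taken from any real layout):

// Illustrative sketch only: evaluates the swizzled column offset that the
// removed getSwizzledSharedPtrs computed with LLVM-dialect ops.
static unsigned swizzledColOffset(unsigned row, unsigned col, unsigned perPhase,
                                  unsigned maxPhase, unsigned outVec,
                                  unsigned minVec) {
  unsigned phase = (row / perPhase) % maxPhase;
  unsigned colOffSwizzled = ((col / outVec) ^ phase) * outVec;
  unsigned colOffOrdered = (col % outVec) / minVec * minVec;
  return colOffSwizzled + colOffOrdered;
}
// Worked example: row = 3, col = 10, perPhase = 2, maxPhase = 4,
//                 outVec = 4, minVec = 2
//   phase          = (3 / 2) % 4        = 1
//   colOffSwizzled = ((10 / 4) ^ 1) * 4 = (2 ^ 1) * 4 = 12
//   colOffOrdered  = (10 % 4) / 2 * 2   = 2
//   result         = 14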
771 | 550 | Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter, |
772 | 551 | Value v); |
773 | 552 | Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter, |
774 | 553 | Value v, RoundingMode rounding); |
775 | 554 |
|
776 | 555 | } // namespace mlir::triton::intel |
777 | 556 |
|
778 | | -#endif |
| 557 | +#endif // TRITON_CONVERSION_TRITONINTELGPU_TO_LLVM_UTILITY_H |