Skip to content

Commit 83683fc

Browse files
authored
[AMD] Remove specific scale preshuffle pattern match (#8247)
This commit switches to using a basic heuristic to improve support for preshuffled scale tensors: we try a few common scale tensor layout schemes and pick the one that gives the largest vectorization for the global load.
1 parent e90d5a3 commit 83683fc

File tree

6 files changed

+162
-83
lines changed

6 files changed

+162
-83
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Tools/LinearLayout.h"
#include <optional>

namespace mlir::triton::gpu {

// Given the result |dstLayout|, infer the source layout that we should use for
// global load if we propagate through op def chain of |defOp|. Returns
// std::nullopt if fails to infer or cannot reach a global load.
std::optional<std::pair<triton::LoadOp, LinearLayout>>
inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp);
// Overload taking the destination layout already wrapped as a
// LinearEncodingAttr attribute; semantics are identical to the above.
std::optional<std::pair<triton::LoadOp, LinearLayout>>
inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp);

} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_triton_library(TritonGPUTransforms
2727
ReorderInstructions.cpp
2828
CoalesceAsyncCopy.cpp
2929
Utility.cpp
30+
LayoutPropagationUtility.cpp
3031
WarpSpecialization/AutomaticWarpSpecialization.cpp
3132
WarpSpecialization/LoadMMASpecialization.cpp
3233
WarpSpecialization/Partition.cpp
@@ -35,6 +36,7 @@ add_triton_library(TritonGPUTransforms
3536
WarpSpecialization/PartitionLoops.cpp
3637
WarpSpecialization/PartitionScheduling.cpp
3738
WarpSpecialization/RewritePartitionDependencies.cpp
39+
3840
DEPENDS
3941
TritonGPUTransformsIncGen
4042

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
4+
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
5+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
6+
#include <optional>
7+
#include <utility>
8+
9+
namespace mlir::triton::gpu {
10+
11+
// Convenience wrapper over the attribute-based overload: packages |dstLayout|
// into a LinearEncodingAttr and delegates. Bails out early when there is no
// defining op to walk through.
std::optional<std::pair<triton::LoadOp, LinearLayout>>
inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp) {
  if (defOp == nullptr)
    return std::nullopt;
  auto dstAttr = LinearEncodingAttr::get(defOp->getContext(), dstLayout);
  return inferSourceLoadLayout(dstAttr, defOp);
}
18+
19+
// Walks the def chain upwards from |defOp|, propagating |dstLayout| through
// layout-preserving / single-operand ops until a tt.load is reached. Returns
// the load op together with the linear layout its result should use for the
// global load, or std::nullopt when the chain cannot be followed (multi-operand
// op, failed encoding inference, non-linear propagated encoding, or no load).
std::optional<std::pair<triton::LoadOp, LinearLayout>>
inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp) {
  Attribute curLayout = dstLayout;
  Operation *curOp = defOp;
  while (curOp) {
    if (isa<triton::LoadOp>(curOp))
      break; // Found the load op; we are done here.

    if (auto cvtOp = dyn_cast<ConvertLayoutOp>(curOp)) {
      // For convert op we keep the current layout to push through further.
      curOp = cvtOp.getSrc().getDefiningOp();
    } else {
      // Only single-operand ops can be walked through unambiguously.
      if (curOp->getNumOperands() != 1)
        break;
      curLayout = inferSrcEncoding(curOp, curLayout);
      // inferSrcEncoding returns a null attribute when it cannot invert the
      // op's layout effect; give up instead of walking on with a null layout,
      // which would crash in the cast below.
      if (!curLayout)
        return std::nullopt;
      curOp = curOp->getOperand(0).getDefiningOp();
    }
  }
  auto loadOp = dyn_cast_or_null<triton::LoadOp>(curOp);
  if (!loadOp)
    return std::nullopt;
  auto loadType = dyn_cast<RankedTensorType>(loadOp.getType());
  if (!loadType)
    return std::nullopt;

  // The propagated encoding may have been rewritten into a non-linear form by
  // inferSrcEncoding; only a linear encoding can be turned back into a
  // LinearLayout here, so fail gracefully rather than hard-cast.
  auto linearAttr = dyn_cast<LinearEncodingAttr>(curLayout);
  if (!linearAttr)
    return std::nullopt;
  return std::make_pair(loadOp, toLinearLayout(loadType.getShape(), linearAttr));
}
48+
49+
} // namespace mlir::triton::gpu

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
#include "mlir/Dialect/SCF/IR/SCF.h"
77
#include "mlir/IR/Dominance.h"
88
#include "mlir/IR/IRMapping.h"
9-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
109
#include "triton/Analysis/AxisInfo.h"
1110
#include "triton/Dialect/Triton/IR/Dialect.h"
1211
#include "triton/Dialect/Triton/IR/Utility.h"
1312
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1413
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1514
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1615
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
17-
#include "llvm/ADT/SetOperations.h"
1816
#include "llvm/Support/Debug.h"
1917

2018
#define DEBUG_TYPE "ttg-utility"

python/test/unit/language/test_matmul.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -712,15 +712,19 @@ def generate_gemm_afp4wfp4_inputs(M, N, K):
712712
kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim
713713

714714
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
715-
_gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
716-
x.stride(0), x.stride(1), w.stride(0), w.stride(1), 0,
717-
triton_out.stride(0), triton_out.stride(1),
718-
x_scales_triton.stride(0), x_scales_triton.stride(1),
719-
w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
720-
BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
721-
num_stages=1, **kernel_kwargs)
715+
k = _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton,
716+
w_scales_triton, M, N, K, x.stride(0), x.stride(1),
717+
w.stride(0), w.stride(1), 0, triton_out.stride(0),
718+
triton_out.stride(1), x_scales_triton.stride(0),
719+
x_scales_triton.stride(1), w_scales_triton.stride(0),
720+
w_scales_triton.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
721+
mfma_nonkdim, preshuffle, num_warps=8, num_stages=1,
722+
**kernel_kwargs)
722723
triton_out = triton_out.to(torch.float32)
723724
torch.testing.assert_close(torch_out, triton_out)
725+
if is_hip() and preshuffle:
726+
assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
727+
assert "ds_read_u8" not in k.asm["amdgcn"]
724728

725729

726730
@pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 79 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@
1010
#include "triton/Dialect/Triton/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1212
#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h"
13+
#include "triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h"
1314
#include "triton/Tools/LayoutUtils.h"
15+
#include "triton/Tools/LinearLayout.h"
1416
#include "llvm/ADT/TypeSwitch.h"
1517

1618
namespace tt = mlir::triton;
1719
namespace ttg = mlir::triton::gpu;
1820
using ::mlir::LLVM::AMD::isChainDotHead;
1921
using ::mlir::LLVM::AMD::isChainDotTail;
20-
using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType;
21-
using mlir::triton::gpu::chooseScaledMfmaScaleLayout;
22+
23+
#undef DEBUG_TYPE
24+
#define DEBUG_TYPE "tritonamd-accelerate-matmul"
2225

2326
namespace mlir {
2427

@@ -217,6 +220,8 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
217220

218221
FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
219222
int mfmaVersion, int nonKDim) {
223+
using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType;
224+
220225
auto ctx = dot.getContext();
221226
int64_t inputKDim = dot.getA().getType().getShape().back();
222227
if (dot.getAElemType() == ScaleDotElemType::E2M1 && dot.getLhsKPack()) {
@@ -779,55 +784,72 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
779784
}
780785
};
781786

782-
template <typename Op> Op getDefOpBeforeConvertLayout(Value op) {
783-
while (auto cvtOp = op.getDefiningOp<ttg::ConvertLayoutOp>()) {
784-
op = cvtOp.getSrc();
785-
}
786-
return op.getDefiningOp<Op>();
787-
}
788-
789-
bool isScaleShuffled(Value scale) {
787+
// Figure out a best tilesPerWarp parameter that gives the largest vector size
// for the global load of the given |scale| tensor feeding into a dot_scaled
// op. Returns the largest vector size found and writes the chosen
// tilesPerWarp to |result|. A null |scale| yields the default {1, 1} choice.
int deduceTilesPerWarp(TypedValue<RankedTensorType> scale, unsigned opIdx,
                       unsigned nonKDim, ArrayRef<unsigned> warpsPerCTA,
                       SmallVectorImpl<unsigned> *result) {
  // Default: one tile per warp in each dimension, vector size 1.
  std::array<unsigned, 2> chosen{1, 1};
  int vecSize = 1;
  if (!scale) {
    result->assign(chosen.begin(), chosen.end());
    return vecSize;
  }

  // Source code has the flexibility to preshuffle the scale tensor to achieve
  // better global load vectorization. That preshuffle scheme is conveyed via
  // some tl.reshape and tl.trans op combinations. Instead of hardcoding one
  // case or pattern matching the op chain here, we try certain scale tensor
  // layouts and see which one gives us better vectorization when pushed
  // upwards to the global load.
  //
  // For the 16x16x128 scaled MFMA intrinsic, each thread only reads one i8
  // value. For better vectorization, we prefer to stick 2x2 such intrinsics
  // together so each thread can read 4xi8 values.
  SmallVector<std::array<unsigned, 2>, 2> choices{{2, 2}, {1, 1}};
  for (const auto &choice : choices) {
    LLVM_DEBUG(llvm::dbgs()
               << "choice: [" << choice[0] << ", " << choice[1] << "]\n");
    LinearLayout layout = ttg::chooseScaledMfmaScaleLayout(
        scale.getContext(), opIdx, scale.getType().getShape(), nonKDim, choice,
        warpsPerCTA);
    LLVM_DEBUG(llvm::dbgs() << "trying scale layout: " << layout << "\n");

    // Infer source layout used for global load using the current scale layout.
    auto loadLayoutPair =
        ttg::inferSourceLoadLayout(layout, scale.getDefiningOp());
    if (!loadLayoutPair)
      continue;
    tt::LoadOp loadOp = loadLayoutPair->first;
    const LinearLayout &inferredLayout = loadLayoutPair->second;
    LLVM_DEBUG(llvm::dbgs()
               << "inferred load layout: " << inferredLayout << "\n");

    auto loadType = cast<RankedTensorType>(loadOp.getType());
    auto loadOrder = ttg::getOrder(loadType);
    auto loadCTALayout = ttg::getCTALayout(loadType.getEncoding());

    // Reuse existing shared memory vectorization utilities by constructing a
    // pass-through shared layout that does linear element mapping (no
    // swizzling), so composing against it measures contiguity of the load.
    MLIRContext *context = scale.getContext();
    auto passThruShared = ttg::SwizzledSharedEncodingAttr::get(
        context, 1, 1, 1, loadOrder, loadCTALayout);
    auto sharedLL =
        triton::gpu::toLinearLayout(loadType.getShape(), passThruShared);
    auto composedLL = inferredLayout.invertAndCompose(sharedLL).flattenOuts();
    // Scales are i8, hence bitwidth 8.
    auto [v, _] =
        largestVectorisation(context, composedLL, /*bitwidth=*/8, std::nullopt);

    // Choices are listed in preference order, so accept the first one that
    // improves over the default and stop searching.
    if (v > vecSize) {
      LLVM_DEBUG(llvm::dbgs() << "found vector size: " << v << "\n");
      chosen = choice;
      vecSize = v;
      break;
    }
  }
  result->assign(chosen.begin(), chosen.end());
  return vecSize;
}
832854

833855
class DecomposeAMDScaledBlocked final : public ttg::DecomposeScaledBlocked {
@@ -968,34 +990,18 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
968990
auto warpsPerTile =
969991
warpsPerTileMFMA(dotOp, oldShape, numWarps, {mDim, nDim});
970992

971-
// For scale tensor preshuffling, the minimum block size is 32x32x256.
972-
// When using MFMA16 instructions, each warp should compute two MFMA ops
973-
// along the non-K dimension. To support this, we must set tilesPerWarp to
974-
// {2, 2}. Failing to do so won't break correctness, but it will prevent
975-
// vectorized local_loads, as the data each thread needs won't be contiguous
976-
// due to the shuffle pattern. This requirement doesn’t apply to MFMA32
977-
// instructions, since only one MFMA op spans the non-K dimension at the
978-
// minimal shuffling size.
979-
SmallVector<unsigned> tilesPerWarp = getTilesPerWarp(aScale, bScale);
980-
981-
if (rank == 3) {
982-
tilesPerWarp.insert(tilesPerWarp.begin(), 1);
983-
}
993+
SmallVector<unsigned, 2> tilesA{1, 1}, tilesB{1, 1}, tilesPerWarp;
994+
int vecA = deduceTilesPerWarp(aScale, 0, mDim, warpsPerTile, &tilesA);
995+
int vecB = deduceTilesPerWarp(bScale, 1, mDim, warpsPerTile, &tilesB);
996+
tilesPerWarp = vecA > vecB ? tilesA : tilesB;
997+
LLVM_DEBUG(llvm::dbgs() << "chosen tilesPerWarp: [" << tilesPerWarp[0]
998+
<< ", " << tilesPerWarp[1] << "]\n");
984999

9851000
// Always use transposed mfma layout. This enables larger vectorization
9861001
// for global store instructions.
987-
mlir::Attribute mfmaEnc;
988-
if (llvm::any_of(tilesPerWarp, [](int x) { return x != 1; })) {
989-
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
990-
ctx, /*verison=*/mfmaVersion, warpsPerTile, tilesPerWarp,
991-
/*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout,
992-
oldRetType.getElementType());
993-
} else {
994-
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
995-
ctx, /*verison=*/mfmaVersion, warpsPerTile,
996-
/*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout,
997-
oldRetType.getElementType());
998-
}
1002+
mlir::Attribute mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
1003+
ctx, /*verison=*/mfmaVersion, warpsPerTile, tilesPerWarp, mDim, nDim,
1004+
/*isTransposed=*/true, ctaLayout, oldRetType.getElementType());
9991005

10001006
auto newRetType =
10011007
RankedTensorType::get(oldShape, oldRetType.getElementType(), mfmaEnc);
@@ -1097,7 +1103,7 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
10971103
shape = llvm::to_vector(scale.getType().getShape());
10981104
}
10991105

1100-
LinearLayout newLL = chooseScaledMfmaScaleLayout(
1106+
LinearLayout newLL = ttg::chooseScaledMfmaScaleLayout(
11011107
ctx, idx, shape, mDim, tilesPerWarp, warpsPerTile);
11021108

11031109
Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL);
@@ -1515,7 +1521,6 @@ struct TritonAMDGPUAccelerateMatmulPass
15151521
using Base::Base;
15161522

15171523
void runOnOperation() override {
1518-
15191524
MLIRContext *context = &getContext();
15201525
ModuleOp m = getOperation();
15211526

0 commit comments

Comments
 (0)