Skip to content

Commit 8dc24ec

Browse files
committed
[Intel] Use 'CTAEncodingAttr' after '49b7472'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b82b6a8 commit 8dc24ec

File tree

10 files changed: 34 additions and 56 deletions.

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,7 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
13891389

13901390
// -----
13911391

1392-
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}>
1392+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}>
13931393
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
13941394
// CHECK-LABEL: test_get_program_id
13951395
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {

test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// COM: Tests reduction when threads_per_warp < num_warps.
44

5-
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
5+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0]}>
66
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
77
// CHECK-LABEL: reduce_problem_size_64_threads_per_warp_32
88
tt.func @reduce_problem_size_64_threads_per_warp_32(%f : tensor<2048xi32, #blocked>) {

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
310310
let parameters = (
311311
ins
312312
ArrayRefParameter<"unsigned">:$warpsPerCTA,
313-
"CTALayoutAttr":$CTALayout,
313+
"CTAEncodingAttr":$CTALayout,
314314
ArrayRefParameter<"unsigned">:$instrShape,
315315
"unsigned":$numBlocks,
316316
ArrayRefParameter<"unsigned">:$order,

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -146,22 +146,12 @@ DpasEncodingAttr::getRepOrderForOperand(OpIdx opIdx) const {
146146
return getOrderForDotOperand(unsigned(opIdx), rank, /*kMajor*/ true);
147147
}
148148

149-
SmallVector<unsigned> DpasEncodingAttr::getCTASplitNum() const {
149+
CTAEncodingAttr DpasEncodingAttr::getCTALayout() const {
150150
size_t rank = getWarpsPerCTA().size();
151-
SmallVector<unsigned> res(rank, 1);
152-
return res;
153-
}
154-
155-
SmallVector<unsigned> DpasEncodingAttr::getCTAOrder() const {
156-
size_t rank = getWarpsPerCTA().size();
157-
auto res = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
158-
return res;
159-
}
160-
161-
SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
162-
size_t rank = getWarpsPerCTA().size();
163-
SmallVector<unsigned> res(rank, 1);
164-
return res;
151+
SmallVector<unsigned> CTAsPerCGA(rank, 1);
152+
auto CTAOrder = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
153+
return CTAEncodingAttr::fromSplitParams(getContext(), CTAsPerCGA, CTAsPerCGA,
154+
CTAOrder);
165155
}
166156

167157
SmallVector<int64_t>
@@ -441,16 +431,8 @@ LinearLayout WarpEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
441431
llvm::report_fatal_error("NYI. WarpEncodingAttr::toLinearLayout");
442432
}
443433

444-
SmallVector<unsigned> WarpEncodingAttr::getCTAsPerCGA() const {
445-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAsPerCGA");
446-
}
447-
448-
SmallVector<unsigned> WarpEncodingAttr::getCTAOrder() const {
449-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAOrder");
450-
}
451-
452-
SmallVector<unsigned> WarpEncodingAttr::getCTASplitNum() const {
453-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTASplitNum");
434+
CTAEncodingAttr WarpEncodingAttr::getCTALayout() const {
435+
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTALayout");
454436
}
455437

456438
Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
@@ -506,16 +488,16 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
506488
//===----------------------------------------------------------------------===//
507489

508490
namespace {
509-
std::optional<CTALayoutAttr> getCTALayoutOrError(
491+
std::optional<CTAEncodingAttr> getCTALayoutOrError(
510492
AsmParser &parser, std::optional<SmallVector<unsigned>> CTAsPerCGA,
511493
std::optional<SmallVector<unsigned>> CTASplitNum,
512494
std::optional<SmallVector<unsigned>> CTAOrder, unsigned rank) {
513495
if (CTAsPerCGA && CTASplitNum && CTAOrder) {
514-
return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum,
515-
*CTAOrder);
496+
return CTAEncodingAttr::fromSplitParams(parser.getContext(), *CTAsPerCGA,
497+
*CTASplitNum, *CTAOrder);
516498
}
517499
if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) {
518-
return CTALayoutAttr::getDefault(parser.getContext(), rank);
500+
return CTAEncodingAttr::getDefault(parser.getContext(), rank);
519501
}
520502
parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder "
521503
"must all be present or all be absent");
@@ -524,8 +506,8 @@ std::optional<CTALayoutAttr> getCTALayoutOrError(
524506

525507
// Print the CTALayout if it's not equal to the default.
526508
void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
527-
CTALayoutAttr layout, unsigned rank) {
528-
if (layout != CTALayoutAttr::getDefault(context, rank)) {
509+
CTAEncodingAttr layout, unsigned rank) {
510+
if (layout != CTAEncodingAttr::getDefault(context, rank)) {
529511
printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]"
530512
<< ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]"
531513
<< ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
@@ -536,7 +518,7 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
536518

537519
LogicalResult Subgroup2DBlockEncodingAttr::verify(
538520
function_ref<InFlightDiagnostic()> emitError,
539-
ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
521+
ArrayRef<unsigned> warpsPerCTA, CTAEncodingAttr CTALayout,
540522
ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
541523
unsigned kWidth, unsigned threadsPerWarp) {
542524
if (instrShape.size() != 2) {
@@ -621,7 +603,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
621603
}
622604
}
623605

624-
std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
606+
std::optional<CTAEncodingAttr> CTALayout = getCTALayoutOrError(
625607
parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
626608
if (!CTALayout.has_value())
627609
return {};
@@ -898,8 +880,10 @@ struct TritonIntelGPUInferLayoutInterface
898880
// Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA
899881
// should be like the other fields in blocked encoding, but I'm not sure how
900882
// to handle CTASplitNum.
901-
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
902-
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
883+
if (!all_of(src.getCTALayout().getCTAsPerCGA(),
884+
[](int32_t x) { return x == 1; }) ||
885+
!all_of(src.getCTALayout().getCTASplitNum(),
886+
[](int32_t x) { return x == 1; })) {
903887
return failure();
904888
}
905889

@@ -1074,7 +1058,7 @@ struct TritonIntelGPUInferLayoutInterface
10741058
auto dstOrder = inversePermutation(dstInvOrder);
10751059

10761060
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
1077-
auto CTALayout = CTALayoutAttr::get(
1061+
auto CTALayout = CTAEncodingAttr::fromSplitParams(
10781062
src.getContext(),
10791063
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
10801064
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ namespace {
2424
// for register layouts, and input dims [offset] for shared layouts.
2525
// - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block].
2626
//
27-
// Note that this is inconsistent with the type name CTALayoutAttr. That type
27+
// Note that this is inconsistent with the type name CTAEncodingAttr. That type
2828
// is equivalent to our cgaLayout.
2929
//
30-
// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway,
30+
// IMO the name CTAEncodingAttr is wrong. If we tried to be consistent anyway,
3131
// then we'd have to rename ctaLayout to "warpLayout". I think that's more
3232
// confusing than being inconsistent about "cgaLayout", especially when we have
3333
// to consider the size of the warpLayout (surely that's not the "warpSize").
@@ -57,8 +57,8 @@ LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
5757
// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups).
5858
//
5959
// See the nomenclature note at the top of the file for an explanation of why
60-
// this is called makeCgaLayout when it accepts a CTALayoutAttr.
61-
LinearLayout makeCgaLayout(CTALayoutAttr layout) {
60+
// this is called makeCgaLayout when it accepts a CTAEncodingAttr.
61+
LinearLayout makeCgaLayout(CTAEncodingAttr layout) {
6262
MLIRContext *ctx = layout.getContext();
6363
StringAttr kBlock = S("block");
6464

@@ -464,7 +464,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
464464
LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
465465

466466
return combineCtaCgaWithShape(std::move(tileLayout),
467-
CTALayoutAttr::getDefault(ctx, rank), shape);
467+
CTAEncodingAttr::getDefault(ctx, rank), shape);
468468
}
469469

470470
LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ struct PrefetchOpConversion
10761076
identityStandardND(S("warp"), warpsPerCTA, order);
10771077

10781078
return combineCtaCgaWithShape(std::move(ctaLayout),
1079-
CTALayoutAttr::getDefault(ctx, rank),
1079+
CTAEncodingAttr::getDefault(ctx, rank),
10801080
tensorShape);
10811081
}
10821082
};

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ using namespace mlir::triton;
2121
using ::mlir::LLVM::delinearize;
2222
using ::mlir::LLVM::SharedMemoryObject;
2323
using ::mlir::triton::gpu::BlockedEncodingAttr;
24-
using ::mlir::triton::gpu::CTALayoutAttr;
2524
using ::mlir::triton::gpu::DotOperandEncodingAttr;
2625
using ::mlir::triton::gpu::SliceEncodingAttr;
2726
using ::mlir::triton::gpu::intel::DpasEncodingAttr;

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "triton/Dialect/Triton/IR/Dialect.h"
1717
#include "triton/Dialect/Triton/IR/Types.h"
1818
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
19-
#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
2019
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
2120
#include "llvm/ADT/STLExtras.h"
2221
#include "llvm/ADT/SetVector.h"

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
349349
1,
350350
dpasEncoding.getWarpsPerCTA()[0]};
351351
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6};
352-
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
352+
CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank);
353353

354354
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
355355
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -407,7 +407,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
407407
dpasEncoding.getWarpsPerCTA()[1],
408408
dpasEncoding.getWarpsPerCTA()[0]};
409409
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4};
410-
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
410+
CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank);
411411

412412
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
413413
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -440,7 +440,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
440440
dpasEncoding.getWarpsPerCTA()[1],
441441
dpasEncoding.getWarpsPerCTA()[0]};
442442
constexpr std::array<unsigned, rank> order{0, 1, 2, 3};
443-
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
443+
CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank);
444444

445445
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
446446
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -483,7 +483,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
483483
std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[1],
484484
dpasEncoding.getWarpsPerCTA()[0]};
485485
constexpr std::array<unsigned, rank> order{0, 1};
486-
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
486+
CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank);
487487

488488
auto parentEncoding = rewriter.getAttr<BlockedEncodingAttr>(
489489
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
3131

3232
// TODO: could put the getOrderForDotOperand in the builder?
3333
auto layout = Subgroup2DBlockEncodingAttr::get(
34-
&ctx, warpsPerCTA,
35-
CTALayoutAttr::get(
36-
&ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout?
37-
dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()),
38-
instrShape, numBlocks,
34+
&ctx, warpsPerCTA, dpasLayout.getCTALayout(), instrShape, numBlocks,
3935
getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth,
4036
dpasLayout.getThreadsPerWarp());
4137
return layout;

0 commit comments

Comments
 (0)