Skip to content

Commit b7dd07d

Browse files
authored
Introduce Subgroup 2D Block Encoding (#4193)
Add a new layout to describe the tensor layout with respect to the GPU compute hierarchy (register, lane, warp, block). This PR introduces the layout and adds its definition and basic functions to the Triton Intel GPU Dialect. The conversion to Linear Layout function has been added and unit tested through an Intel specific `LinearLayoutConversionsTest`. The layouts are unpacked - each register is assumed to be the size of the tensor type. However, the layout generation follows the convention described in the SPV_INTEL_2d_block_io extension: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html. While there may be some bugs, the goal is for any valid operation described in the SPIR-V extension to be represented correctly with this layout. Currently the layout is unused other than for linear layout conversion testing purposes. I plan to leave this PR in draft until I have replaced the `block_io` attribute on the load ops with this layout - and then I plan to replace the linear layout code I added to `LoadStoreOpToLLVM.cpp`. That second task might prove challenging since I think the DPAS layouts do sometimes incorporate register packing schemes into the layout - but looking at the upstream layouts for NVIDIA and AMD MMA, specific packing is an implementation detail and not represented as part of the high-level layout encoding. cc #4192
1 parent 0bb3280 commit b7dd07d

File tree

7 files changed

+524
-0
lines changed

7 files changed

+524
-0
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
1818
LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
1919
ArrayRef<int64_t> shape);
2020

21+
LinearLayout
22+
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> shape,
23+
intel::Subgroup2DBlockEncodingAttr layout,
24+
unsigned kWidth);
25+
2126
} // namespace mlir::triton::gpu
2227

2328
#endif // TRITON_DIALECT_TRITONINTELGPU_IR_LINEARLAYOUTCONVERSIONS_H

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,47 @@ def WarpEncodingAttr : DistributedEncoding<"WarpEncoding", "intel_warp_encoding"
280280
let hasCustomAssemblyFormat = 1;
281281
}
282282

//===----------------------------------------------------------------------===//
// Intel Subgroup2DBlock Encoding
//===----------------------------------------------------------------------===//

def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "subgroup_2d_block_encoding", [MmaEncodingTrait], TritonIntelGPU_Dialect> {
  let mnemonic = "subgroup_2d_block";

  let description = [{
    An encoding for tensors produced via Intel Subgroup 2D Block IO operations.

    The subgroup 2D block IO operations read or write two-dimensional blocks of data from a two-dimensional region of memory. The Subgroup 2D Block Encoding layout is parameterized by the block width, block height, and block count for the individual load instructions and the distribution and replication of loads across warps.

    The SPV_INTEL_2d_block_io extension documentation provides more information on the subgroup 2D block IO operations and parameters: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html

    For the layout, the following parameters are required:
    - `instrShape` : contains the (height, width) block parameters for the block io operation
    - `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization)
    - `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup / warp configurations. Because the 2d block io operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor.
    - `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution
    - `order` : The order within the block, used to determine along which dimension to broadcast.
    - `kWidth` : Currently unused, but keeping because we will likely need it for layout conversions.
    - `CTALayout` : Describes how blocks are distributed among work-groups/thread blocks.
  }];

  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$warpsPerCTA,
    "CTALayoutAttr":$CTALayout,
    ArrayRefParameter<"unsigned">:$instrShape,
    "unsigned":$numBlocks,
    ArrayRefParameter<"unsigned">:$order,
    "unsigned":$kWidth,
    "unsigned":$threadsPerWarp
  );

  let extraClassDeclaration = extraDistributedDeclaration # [{
    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
  }];

  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
}
283326
#endif

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,173 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
495495
<< "}>";
496496
}
497497

498+
//===----------------------------------------------------------------------===//
499+
// Subgroup2DBlockEncodingAttr
500+
//===----------------------------------------------------------------------===//
501+
502+
namespace {
503+
std::optional<CTALayoutAttr> getCTALayoutOrError(
504+
AsmParser &parser, std::optional<SmallVector<unsigned>> CTAsPerCGA,
505+
std::optional<SmallVector<unsigned>> CTASplitNum,
506+
std::optional<SmallVector<unsigned>> CTAOrder, unsigned rank) {
507+
if (CTAsPerCGA && CTASplitNum && CTAOrder) {
508+
return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum,
509+
*CTAOrder);
510+
}
511+
if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) {
512+
return CTALayoutAttr::getDefault(parser.getContext(), rank);
513+
}
514+
parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder "
515+
"must all be present or all be absent");
516+
return std::nullopt;
517+
}
518+
519+
// Print the CTALayout if it's not equal to the default.
520+
void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
521+
CTALayoutAttr layout, unsigned rank) {
522+
if (layout != CTALayoutAttr::getDefault(context, rank)) {
523+
printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]"
524+
<< ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]"
525+
<< ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
526+
}
527+
}
528+
529+
} // namespace
530+
531+
// Verify the structural invariants of a Subgroup2DBlockEncodingAttr:
// instrShape, order, and warpsPerCTA must all be rank 2; kWidth must be one
// of {1, 2, 4}; threadsPerWarp is currently restricted to 16.
LogicalResult Subgroup2DBlockEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
    ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
    unsigned kWidth, unsigned threadsPerWarp) {
  if (instrShape.size() != 2) {
    return emitError() << "instrShape must be rank 2 but was: "
                       << instrShape.size();
  }
  if (order.size() != 2) {
    return emitError() << "order must be rank 2 but was " << order.size();
  }
  if (warpsPerCTA.size() != 2) {
    return emitError() << "warpsPerCTA must be rank 2 but was "
                       << warpsPerCTA.size();
  }
  if (!(kWidth == 1 || kWidth == 2 || kWidth == 4)) {
    return emitError() << "kWidth must be 1, 2 or 4, but was: " << kWidth;
  }
  // Bug fix: the original `!threadsPerWarp == 16` applies the negation first
  // (yielding 0 or 1), so the comparison was always false and the check never
  // fired. Compare the value directly instead.
  if (threadsPerWarp != 16) {
    return emitError() << "threadsPerWarp must be 16, but was: "
                       << threadsPerWarp;
  }
  return success();
}
// Parse the custom assembly form `<{key = value, ...}>`. The CTA fields
// (CTAsPerCGA, CTASplitNum, CTAOrder) are optional as a group; all other
// keys populate the attribute parameters directly.
Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned> warpsPerCTA;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;
  SmallVector<unsigned> instrShape;
  unsigned numBlocks = 0;
  SmallVector<unsigned> order;
  unsigned kWidth = 0;
  unsigned threadsPerWarp = 0;

  // Dictionary keys are unique, so dispatching with an else-if chain is
  // equivalent to testing each key independently.
  for (const NamedAttribute &namedAttr : dict) {
    auto key = namedAttr.getName();
    if (key == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, namedAttr, warpsPerCTA, "warpsPerCTA")
              .failed())
        return {};
    } else if (key == "CTAsPerCGA") {
      if (parseIntArrayAttr(parser, namedAttr, CTAsPerCGA.emplace(),
                            "CTAsPerCGA")
              .failed())
        return {};
    } else if (key == "CTASplitNum") {
      if (parseIntArrayAttr(parser, namedAttr, CTASplitNum.emplace(),
                            "CTASplitNum")
              .failed())
        return {};
    } else if (key == "CTAOrder") {
      if (parseIntArrayAttr(parser, namedAttr, CTAOrder.emplace(), "CTAOrder")
              .failed())
        return {};
    } else if (key == "instrShape") {
      if (parseIntArrayAttr(parser, namedAttr, instrShape, "instrShape")
              .failed())
        return {};
    } else if (key == "numBlocks") {
      if (parseUInt(parser, namedAttr, numBlocks, "numBlocks").failed())
        return {};
    } else if (key == "order") {
      if (parseIntArrayAttr(parser, namedAttr, order, "order").failed())
        return {};
    } else if (key == "kWidth") {
      if (parseUInt(parser, namedAttr, kWidth, "kWidth").failed())
        return {};
    } else if (key == "threadsPerWarp") {
      if (parseUInt(parser, namedAttr, threadsPerWarp, "threadsPerWarp")
              .failed())
        return {};
    }
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
  if (!CTALayout.has_value())
    return {};

  return parser.getChecked<Subgroup2DBlockEncodingAttr>(
      parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks,
      order, kWidth, threadsPerWarp);
}
// Replication order for this encoding follows row-major matrix order.
SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getRepOrder() const {
  return getMatrixOrder(getRank(), /*rowMajor*/ true);
}

// The CTA-level accessors below simply forward to the wrapped CTALayoutAttr.
SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAsPerCGA() const {
  return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAOrder() const {
  return SmallVector<unsigned>(getCTALayout().getCTAOrder());
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTASplitNum() const {
  return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
}

// Dimension order when this layout is used as dot operand `opIdx`, with the
// K dimension contiguous.
SmallVector<unsigned>
Subgroup2DBlockEncodingAttr::getRepOrderForOperand(int opIdx) const {
  return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true);
}
// Custom assembly printer; must stay in sync with `parse` above.
// Consistency fix: the original mixed `key = value` (warpsPerCTA, CTA fields)
// with `key=value` (numBlocks, order, kWidth, threadsPerWarp). All fields now
// use `key = value`; MLIR attribute parsing is whitespace-insensitive, so the
// round-trip behavior is unchanged.
void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{" << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]";

  maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());

  printer << ", instrShape = [" << getInstrShape()
          << "], numBlocks = " << getNumBlocks() << ", order = ["
          << getOrder() << "], kWidth = " << getKWidth()
          << ", threadsPerWarp = " << getThreadsPerWarp() << "}>";
}
// Delegate to the standalone conversion helper. kWidth is forwarded even
// though the conversion does not use it yet (kept for future conversions).
LinearLayout
Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  return subgroup2DBlockToLinearLayout(shape, *this, getKWidth());
}
498665
//===----------------------------------------------------------------------===//
499666
// Dialect Interface
500667
//===----------------------------------------------------------------------===//

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,4 +523,119 @@ LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
523523
return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
524524
}
525525

526+
namespace {
527+
528+
static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
529+
ArrayRef<unsigned> shape,
530+
ArrayRef<unsigned> order,
531+
unsigned broadcastDim,
532+
StringAttr inDimName) {
533+
int rank = shape.size();
534+
auto dimNames = standardOutDimNames(ctx, rank);
535+
LinearLayout layout = LinearLayout::empty();
536+
537+
for (auto d : order) {
538+
if (d == broadcastDim) {
539+
layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
540+
} else {
541+
layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
542+
}
543+
}
544+
return layout;
545+
}
546+
547+
using basisT = std::vector<std::vector<int32_t>>;
548+
549+
// Creates a row major tile layout with register/lane input dimensions according
550+
// to the provided height, width, and threadsPerWarp. The relationship between
551+
// the width and threadsPerWarp determines the packing of rows across lanes:
552+
// - if width == threadsPerWarp:
553+
// block row elements are mapped to registers in row major order, i.e. one
554+
// column per lane
555+
// - if width < threadsPerWarp:
556+
// multiple rows are mapped to the first register to fill the warp, i.e.
557+
// width * rowsPerWarp = threadsPerWarp
558+
// - if width > threadsPerWarp:
559+
// multiple elements of each row are assigned to registers such that
560+
// packedElementsPerLane row values exist in consecutive registers for each
561+
// lane
562+
std::pair<basisT, basisT>
563+
createRegisterLaneBases(const int height, const int width,
564+
const unsigned threadsPerWarp) {
565+
const int packedElementsPerLane =
566+
mlir::ceil<int>(width, static_cast<int>(threadsPerWarp));
567+
568+
basisT laneBases;
569+
for (int i = packedElementsPerLane; i < width; i = i << 1) {
570+
laneBases.push_back({0, i});
571+
}
572+
573+
const int rowsPerWarp =
574+
mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size());
575+
// Place subsequent rows into adjacent lanes until all lanes have been filled
576+
for (int i = 1; i < rowsPerWarp; i = i << 1) {
577+
laneBases.push_back({i, 0});
578+
}
579+
580+
basisT regBases;
581+
582+
// Add packed row-wise elements (width > threadsPerWarp) before adding columns
583+
for (int i = 1; i < packedElementsPerLane; i = i << 1) {
584+
regBases.push_back({0, i});
585+
}
586+
587+
for (int i = 1; i < height / rowsPerWarp; i = i << 1) {
588+
regBases.push_back({i * rowsPerWarp, 0});
589+
}
590+
591+
return std::make_pair(regBases, laneBases);
592+
}
593+
594+
} // namespace
595+
596+
// Convert a Subgroup2DBlockEncodingAttr to a LinearLayout with standard
// (register, lane, warp) input dimensions for a tensor of `blockShape`.
// `kWidth` is accepted for interface stability but is currently unused by
// the conversion.
LinearLayout
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> blockShape,
                              intel::Subgroup2DBlockEncodingAttr layout,
                              unsigned kWidth) {
  auto ctx = layout.getContext();
  int rank = blockShape.size();
  assert(rank == layout.getRank() && "unexpected block shape rank, layout rank "
                                     "and block shape rank must be equal");
  auto dimNames = standardOutDimNames(ctx, rank);
  auto loadTileSize = layout.getInstrShape();
  StringAttr kRegister = S("register");
  StringAttr kLane = S("lane");
  StringAttr kWarp = S("warp");

  // Start by creating register/lane bases corresponding to the desired load
  // tile size (instrShape = (height, width)).
  auto [regBases, laneBases] = createRegisterLaneBases(
      loadTileSize[0], loadTileSize[1], layout.getThreadsPerWarp());

  LinearLayout::BasesT bases;
  bases[kRegister] = regBases;
  bases[kLane] = laneBases;
  auto ctaLayout = LinearLayout(bases, dimNames);

  assert(ctaLayout.getInDimSize(kLane) <= layout.getThreadsPerWarp() &&
         "number of lanes should not exceed threads per warp");

  // Increasing the block count always increases the inner dimension for the
  // register/lane layout regardless of order.
  ctaLayout *=
      LinearLayout::identity1D(layout.getNumBlocks(), kRegister, dimNames[1]);

  // Broadcast the layout according to warpsPerCTA, then combine with the
  // overall CTALayout and reshape according to the provided blockShape.
  auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ true);
  auto order = layout.getOrder();
  assert(order.size() == 2 && "only rank 2 order supported");
  unsigned inner = order[0];

  ctaLayout *= broadcastedDotOperandLayout(ctx, layout.getWarpsPerCTA(),
                                           warpOrder, inner, kWarp)
                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
  return combineCtaCgaWithShape(ctaLayout, layout.getCTALayout(), blockShape);
}
526641
} // namespace mlir::triton::gpu

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,20 @@ struct LoadOpConversion
19651965
}
19661966
Value elemSizeInBytes = b.i32_val(originalElemBits / 8);
19671967

1968+
LLVM_DEBUG({
1969+
const unsigned numLoads = numRepOuter * numLoadPerOutRepCluster *
1970+
numRepInner / numOperandsInnerDimPerLoad;
1971+
llvm::dbgs() << "Preparing to dispatch " << numLoads << " loads\n";
1972+
llvm::dbgs() << "Outer loads: " << numRepOuter * numLoadPerOutRepCluster
1973+
<< " (" << numLoadPerOutRepCluster
1974+
<< " per out rep cluster)\n";
1975+
llvm::dbgs() << "Inner loads: "
1976+
<< numRepInner / numOperandsInnerDimPerLoad << "\n";
1977+
llvm::dbgs() << "Load dimension: " << tileHeight << ", "
1978+
<< tileWidth * vBlocks << " (" << elemSizeInBits
1979+
<< " bits)\n";
1980+
});
1981+
19681982
ValueTable loadVals;
19691983
for (int outer = 0; outer < numRepOuter; ++outer) {
19701984
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {

third_party/intel/unittest/Dialect/TritonIntelGPU/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ add_triton_ut(
88
TritonIntelGPUTransforms
99
TritonNvidiaGPUTransforms
1010
)
# Intel-specific linear layout conversion unit tests.
add_triton_ut(
  NAME LinearLayoutConversionsIntel
  SRCS LinearLayoutConversionsTest.cpp
  LIBS
    TritonGPUIR
    TritonGPUTransforms
    TritonIntelAnalysis
    TritonIntelGPUTransforms
    TritonNvidiaGPUTransforms
)

0 commit comments

Comments
 (0)