Skip to content

Commit e301b43

Browse files
committed
WIP: use the subgroup 2d block layout in LoadStoreOpToLLVM
add missing definition WIP: use new encoding in load store op to llvm
1 parent f7e81ce commit e301b43

File tree

4 files changed

+53
-9
lines changed

4 files changed

+53
-9
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding);
3333
// Returns true if the operation is an expensive load or store operation.
3434
bool isExpensiveLoadOrStore(Operation *op);
3535

36+
// Returns true if the tensor type has a subgroup 2d block io encoding
37+
bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType);
38+
3639
// Returns true if the tensor type has a dot dpas encoding.
3740
bool hasDotDpasEncoding(RankedTensorType tensorType);
3841

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
302302

303303
// Only lower loadOp with dpas layout encoding.
304304
auto tensorTy = cast<RankedTensorType>(op.getType());
305-
return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy);
305+
return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy) ||
306+
hasSubgroup2DBlockEncoding(tensorTy);
306307
}
307308

308309
template <
@@ -1416,12 +1417,31 @@ struct LoadOpConversion
14161417
auto tensorType = cast<RankedTensorType>(resultType);
14171418

14181419
const bool memoryRowMajor = isMemoryRowMajor(op);
1419-
DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
1420+
1421+
auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422+
for (OpOperand user : opResult.getUsers()) {
1423+
if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner())) {
1424+
return cast<RankedTensorType>(cvt.getResult().getType());
1425+
// return getDpasLayout(cvt.getResult().getType());
1426+
}
1427+
}
1428+
llvm_unreachable("expected to find a cvt op with dpas layout");
1429+
};
1430+
1431+
auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType)
1432+
? getDpasTypeFromCVTOp(op.getResult())
1433+
: tensorType;
1434+
llvm::errs() << "using dpas tensor type: " << dpasTensorType << "\n";
1435+
DpasEncodingAttr dpasLayout = getDpasLayout(dpasTensorType);
1436+
1437+
DpasEncodingAttr::OpIdx opIdx = getOpIdx(dpasTensorType);
14201438

14211439
LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": "
14221440
<< tensorType << "\n");
14231441

14241442
Attribute encoding = tensorType.getEncoding();
1443+
// TODO: this gives us the linear layout corresponding
1444+
// to the subgroup 2d block encoding, not the dpas encoding...
14251445
std::optional<LinearLayout> llEncoding =
14261446
cast<DistributedEncodingTrait>(encoding).toLinearLayout(
14271447
tensorType.getShape());
@@ -1440,14 +1460,21 @@ struct LoadOpConversion
14401460
Type eltTy = tensorType.getElementType();
14411461
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
14421462

1443-
auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
1444-
cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
1445-
memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
1446-
unsigned tileHeight = tileParams[0];
1447-
const unsigned tileWidth = tileParams[1];
1448-
const unsigned vBlocks = tileParams[2];
1463+
auto getTileParams = [&]() -> std::tuple<unsigned, unsigned, unsigned> {
1464+
if (hasSubgroup2DBlockEncoding(tensorType)) {
1465+
auto encoding =
1466+
cast<Subgroup2DBlockEncodingAttr>(tensorType.getEncoding());
1467+
auto shape = encoding.getInstrShape();
1468+
return std::make_tuple(shape[0], shape[1], encoding.getNumBlocks());
1469+
} else {
1470+
auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
1471+
cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
1472+
memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
1473+
return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]);
1474+
}
1475+
};
1476+
auto [tileHeight, tileWidth, vBlocks] = getTileParams();
14491477

1450-
DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
14511478
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
14521479
unsigned numElems = getTotalElemsPerThread(resultType);
14531480
SmallVector<int64_t> numReps =
@@ -1617,6 +1644,7 @@ struct LoadOpConversion
16171644
// input operands to DPAS.
16181645
// TODO: add support for int4 and int2.
16191646
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
1647+
llvm::errs() << "opsPerChannel = " << opsPerChannel << "\n";
16201648
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
16211649
(opsPerChannel == 2 && elemSizeInBits == 16) ||
16221650
(opsPerChannel == 1 && elemSizeInBits == 32)) {
@@ -1840,6 +1868,8 @@ struct LoadOpConversion
18401868
unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
18411869
numOperandsOuterDimPerLoad *
18421870
numOperandsInnerDimPerLoad;
1871+
llvm::errs() << "num values per load = " << numValuesPerLoad << "\n";
1872+
llvm::errs() << "loadResultElemType = " << loadResultElemType << "\n";
18431873
Type load2DGenXType =
18441874
LLVM::getVectorType(loadResultElemType, numValuesPerLoad);
18451875

@@ -2187,6 +2217,8 @@ struct LoadOpConversion
21872217
}
21882218

21892219
Type llvmResultStructTy = typeConverter->convertType(op.getType());
2220+
llvm::errs() << "op.getType() " << op.getType() << "\n";
2221+
llvm::errs() << "llvmResultStructTy: " << llvmResultStructTy << "\n";
21902222
Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
21912223
rewriter, llvmResultStructTy);
21922224
rewriter.replaceOp(op, {resultStruct});

third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class TritonIntelGPUReduceDataDuplicationPass
3030
auto srcEncoding = srcType.getEncoding();
3131
if (isa<triton::gpu::SharedEncodingTrait>(srcEncoding))
3232
return;
33+
if (isa<intel::Subgroup2DBlockEncodingAttr>(srcEncoding))
34+
return;
3335
auto dstDotOp =
3436
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
3537
if (!dstDotOp)

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ bool isExpensiveLoadOrStore(Operation *op) {
153153
return false;
154154
}
155155

156+
bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType) {
157+
if (!tensorType.getEncoding())
158+
return false;
159+
160+
return isa<ttgi::Subgroup2DBlockEncodingAttr>(tensorType.getEncoding());
161+
}
162+
156163
bool hasDotDpasEncoding(RankedTensorType tensorType) {
157164
if (!tensorType.getEncoding())
158165
return false;

0 commit comments

Comments
 (0)