Skip to content

Commit 4c58d3d

Browse files
committed
pass basic lowering test
1 parent 664b227 commit 4c58d3d

File tree

8 files changed

+466
-18
lines changed

8 files changed

+466
-18
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
716716
return getAttrs().getAs<ArrayAttr>("stride");
717717
}
718718

719+
ArrayAttr getBlockAttr() {
720+
return getAttrs().getAs<ArrayAttr>("block");
721+
}
722+
719723
}];
720724

721725
}
722726

727+
def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
728+
def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
729+
def MatrixAccessDirection :
730+
I32EnumAttr<"MatrixAccessDirection",
731+
"Matrix elements/vectors can have row or column direction", [
732+
RowOriented, ColOriented
733+
]> {
734+
let genSpecializedAttr = 0;
735+
let cppNamespace = "::mlir::xegpu";
736+
}
737+
def MatrixAccessDirectionAttr :
738+
EnumAttr<XeGPU_Dialect,
739+
MatrixAccessDirection,
740+
"matrix_access_direction">{
741+
let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
742+
let assemblyFormat = "`<` $value `>`";
743+
}
744+
723745
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,8 +1298,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
12981298
}
12991299

13001300
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1301-
AllElementTypesMatch<["mem_desc", "res"]>,
1302-
AllRanksMatch<["mem_desc", "res"]>]> {
1301+
AllElementTypesMatch<["mem_desc", "res"]>]> {
13031302
let arguments = (ins XeGPU_MemDesc:$mem_desc,
13041303
Variadic<Index>: $offsets,
13051304
DenseI64ArrayAttr: $const_offsets,
@@ -1344,8 +1343,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13441343
}
13451344

13461345
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1347-
AllElementTypesMatch<["mem_desc", "data"]>,
1348-
AllRanksMatch<["mem_desc", "data"]>]> {
1346+
AllElementTypesMatch<["mem_desc", "data"]>]> {
13491347
let arguments = (ins
13501348
XeGPU_ValueType:$data,
13511349
XeGPU_MemDesc:$mem_desc,

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
237237
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
238238
}
239239

240-
ArrayAttr getStrides() {
240+
ArrayAttr getStridesAttr() {
241241
auto layout = getMemLayout();
242242
if (layout && layout.hasAttr("stride")) {
243243
return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
250250
Builder builder(getContext());
251251
return builder.getI64ArrayAttr(defaultStrides);
252252
}
253+
254+
/// Heuristic to determine if the MemDesc uses column-major layout,
255+
/// based on the rank and the value of the first stride dimension.
256+
bool isColMajor() {
257+
auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
258+
return getRank() == 2 && dim0 && dim0.getInt() == 1;
259+
}
260+
261+
// Get the blocking shape for a MemDescType, which is represented
262+
// as an attribute in MemDescType. By default it is the shape
263+
// of the mdescTy
264+
SmallVector<int64_t> getBlockSize() {
265+
SmallVector<int64_t> size(getShape());
266+
MemLayoutAttr layout = getMemLayout();
267+
if (layout && layout.hasAttr("block")) {
268+
ArrayAttr attr = layout.getBlockAttr();
269+
size.clear();
270+
llvm::for_each(attr, [&](Attribute elem) {
271+
if (auto intElem = dyn_cast<IntegerAttr>(elem))
272+
size.push_back(intElem.getInt());
273+
});
274+
}
275+
return size;
276+
}
277+
278+
// Get strides as a vector of integers.
279+
// If it contains block attribute, the strides are blocked strides.
280+
//
281+
// The blocking is applied against the original matrix shape
282+
// so that the linear offset is not impacted by the subview.
283+
//
284+
// It first computes the original matrix shape using the stride info,
285+
// then computes the number of blocks in each dimension of the original shape,
286+
// then computes the outer block shape and stride,
287+
// then combines the inner and outer block shape and stride
288+
// e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
289+
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
290+
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
291+
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
292+
SmallVector<int64_t> getStrides();
293+
294+
/// Generates instructions to compute the linearized offset.
295+
// If the memory descriptor is blocked, it returns a linearized offset based on the blocked layout.
296+
// The strides of the memory descriptor are always considered, regardless of blocking.
297+
Value getLinearOffsets(OpBuilder &builder,
298+
Location loc, ArrayRef<OpFoldResult> offsets);
299+
300+
253301
}];
254302

255303
let hasCustomAssemblyFormat = true;

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
#include <numeric>
3434

35+
#define DEBUG_TYPE "xegpu-to-xevm"
36+
3537
namespace mlir {
3638
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
3739
#include "mlir/Conversion/Passes.h.inc"
@@ -60,6 +62,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
6062
return static_cast<int>(xevm::AddrSpace::GLOBAL);
6163
case xegpu::MemorySpace::SLM:
6264
return static_cast<int>(xevm::AddrSpace::SHARED);
65+
default:
66+
llvm_unreachable("Unknown XeGPU memory space");
67+
return static_cast<int>(xevm::AddrSpace::GLOBAL);
6368
}
6469
}
6570

@@ -366,6 +371,7 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
366371
Value baseAddr, Value offset, int64_t elemByteSize) {
367372
Value byteSize = arith::ConstantIntOp::create(
368373
rewriter, loc, rewriter.getI64Type(), elemByteSize);
374+
offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
369375
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
370376
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
371377
return newAddr;
@@ -503,6 +509,113 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
503509
}
504510
};
505511

512+
// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
513+
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
514+
// 32 bits will be converted to 32 bits.
515+
class CreateMemDescOpPattern final
516+
: public OpConversionPattern<xegpu::CreateMemDescOp> {
517+
public:
518+
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
519+
LogicalResult
520+
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
521+
ConversionPatternRewriter &rewriter) const override {
522+
// DEBUG: Print operation and types
523+
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
524+
TypedValue<MemRefType> src = op.getSource();
525+
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
526+
527+
// Create the result MemRefType with the same shape, element type, and memory space
528+
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
529+
530+
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
531+
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
532+
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
533+
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
534+
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
535+
ValueRange());
536+
rewriter.replaceOp(op, viewOp);
537+
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
538+
return success();
539+
}
540+
};
541+
542+
class MemDescSubviewOpPattern final
543+
: public OpConversionPattern<xegpu::MemDescSubviewOp> {
544+
public:
545+
using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
546+
LogicalResult
547+
matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
548+
ConversionPatternRewriter &rewriter) const override {
549+
return rewriter.notifyMatchFailure(
550+
op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
551+
}
552+
};
553+
554+
555+
template <typename OpType,
556+
typename = std::enable_if_t<llvm::is_one_of<
557+
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
558+
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
559+
using OpConversionPattern<OpType>::OpConversionPattern;
560+
LogicalResult
561+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
562+
ConversionPatternRewriter &rewriter) const override {
563+
564+
SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
565+
if (offsets.empty())
566+
return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
567+
568+
auto loc = op.getLoc();
569+
auto ctxt = rewriter.getContext();
570+
Value basePtrStruct = adaptor.getMemDesc();
571+
Value mdescVal = op.getMemDesc();
572+
// Load result or Store value Type can be vector or scalar.
573+
Value data;
574+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
575+
data = op.getResult();
576+
else
577+
data = adaptor.getData();
578+
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
579+
580+
int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
581+
// Element type must be multiple of 8 bits.
582+
if (elemBitWidth % 8 != 0)
583+
return rewriter.notifyMatchFailure(
584+
op, "Expected element type bit width to be multiple of 8.");
585+
int64_t elemByteSize = elemBitWidth / 8;
586+
587+
// Default memory space is SLM.
588+
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
589+
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
590+
591+
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
592+
593+
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
594+
595+
// Convert base pointer (ptr) to i64
596+
Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
597+
598+
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
599+
basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
600+
601+
// convert base pointer (i64) to LLVM pointer type
602+
basePtrLLVM =
603+
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
604+
605+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
606+
607+
Value loadOp =
608+
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
609+
rewriter.replaceOp(op, loadOp);
610+
} else {
611+
auto storeOp =
612+
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
613+
rewriter.eraseOp(op);
614+
}
615+
return success();
616+
}
617+
};
618+
506619
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
507620
using OpConversionPattern::OpConversionPattern;
508621
LogicalResult
@@ -785,6 +898,13 @@ struct ConvertXeGPUToXeVMPass
785898
auto i32Type = IntegerType::get(&getContext(), 32);
786899
return VectorType::get(8, i32Type);
787900
});
901+
// Convert MemDescType into flattened MemRefType for SLM
902+
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
903+
Type elemTy = type.getElementType();
904+
int numElems = type.getNumElements();
905+
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
906+
});
907+
788908
typeConverter.addConversion([&](MemRefType type) -> Type {
789909
// Convert MemRefType to i64 type.
790910
return IntegerType::get(&getContext(), 64);
@@ -919,6 +1039,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
9191039
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
9201040
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
9211041
typeConverter, patterns.getContext());
1042+
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1043+
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1044+
CreateMemDescOpPattern, MemDescSubviewOpPattern>(
1045+
typeConverter, patterns.getContext());
9221046
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
9231047
patterns.getContext());
9241048
}

0 commit comments

Comments
 (0)