4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,6 +716,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().getAs<ArrayAttr>("stride");
}

ArrayAttr getBlockAttr() {
return getAttrs().getAs<ArrayAttr>("block");
}

}];

}
22 changes: 14 additions & 8 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}

def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
AllElementTypesMatch<["mem_desc", "res"]>,
AllRanksMatch<["mem_desc", "res"]>]> {
AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
Contributor: please update the description of the op with the meaning of block_io
);
let results = (outs XeGPU_ValueType:$res);
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1336,21 +1336,24 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}

ArrayRef<int64_t> getDataShape() {
return getRes().getType().getShape();
auto resTy = getRes().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
return vecTy.getShape();
return {};
}
}];

let hasVerifier = 1;
}

def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
AllElementTypesMatch<["mem_desc", "data"]>,
AllRanksMatch<["mem_desc", "data"]>]> {
AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
XeGPU_ValueType:$data,
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<UnitAttr>:$subgroup_block_io,
Contributor: update description.
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1378,7 +1381,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}

ArrayRef<int64_t> getDataShape() {
return getData().getType().getShape();
auto DataTy = getData().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
return vecTy.getShape();
return {};
}

}];
58 changes: 56 additions & 2 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,19 +237,73 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}

ArrayAttr getStrides() {
ArrayAttr getStridesAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
}

// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
llvm::append_values(defaultStrides, 1);
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}

ArrayAttr getBlockAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("block")) {
return layout.getBlockAttr();
}
Builder builder(getContext());
return builder.getI64ArrayAttr({});
}

/// Heuristic to determine if the MemDesc uses column-major layout,
/// based on the rank and the value of the first stride dimension.
bool isColMajor() {
auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
return getRank() == 2 && dim0 && dim0.getInt() == 1;
Contributor: why check dim0 exists? it should always exist right?
}

// get the Blocking shape for a MemDescType, Which is represented
Contributor: nit: Capitalize the first letter of all comment sentences per coding standards.
// as an attribute in MemDescType. By default it is the shape
// of the mdescTy
SmallVector<int64_t> getBlockShape() {
SmallVector<int64_t> size(getShape());
ArrayAttr blockAttr = getBlockAttr();
if (!blockAttr.empty()) {
size.clear();
for (auto attr : blockAttr.getValue()) {
size.push_back(cast<IntegerAttr>(attr).getInt());
}
}
return size;
}

// Get strides as vector of integer.
// If it contains block attribute, the strides are blocked strides.
//
// The blocking is applied against the original matrix shape
// so that the linear offset is not impacted by the subview.
Contributor: what is the subview you refer to here? is it the subview to a specific block?
//
// It first computes the original matrix shape using the stride info,
// then computes the number of blocks in each dimension of original shape,
// then compute the outer block shape and stride,
// then combines the inner and outer block shape and stride
Contributor: nit: use code quotes for (mem_desc) for code examples. That way doxygen will generate more readable docs.
// e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
SmallVector<int64_t> getStrideShape();
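A minimal standalone sketch of the blocking computation described above (not part of this patch; the helper names are hypothetical, and it assumes the shape divides evenly by the block sizes and that whole blocks are laid out contiguously in the matrix's original major order):

// Standalone sketch: derive the blocked layout tuple (shape, strides) for a
// mem_desc, matching the two examples above. Strides are in elements.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

struct LayoutTuple {
  std::vector<int64_t> shape;   // outer (block-count) dims then inner (block) dims
  std::vector<int64_t> strides; // matching strides
};

static LayoutTuple getBlockedLayout(const std::vector<int64_t> &shape,
                                    const std::vector<int64_t> &strides,
                                    const std::vector<int64_t> &block) {
  size_t rank = shape.size();
  // Visit dims from fastest-varying (smallest stride) to slowest.
  std::vector<size_t> order(rank);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return strides[a] < strides[b]; });

  std::vector<int64_t> outerShape(rank), innerStride(rank), outerStride(rank);
  int64_t blockSize = 1;
  for (size_t i = 0; i < rank; ++i) {
    outerShape[i] = shape[i] / block[i]; // number of blocks along dim i
    blockSize *= block[i];               // elements per block
  }
  // Elements keep the original major order within a block; whole blocks are
  // then laid out contiguously in the same order.
  int64_t inner = 1, outer = blockSize;
  for (size_t d : order) {
    innerStride[d] = inner;
    inner *= block[d];
    outerStride[d] = outer;
    outer *= outerShape[d];
  }
  LayoutTuple r;
  r.shape = outerShape;
  r.shape.insert(r.shape.end(), block.begin(), block.end());
  r.strides = outerStride;
  r.strides.insert(r.strides.end(), innerStride.begin(), innerStride.end());
  return r;
}

int main() {
  // mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
  //   -> ([2,32,16,8], [128,256,1,16])
  LayoutTuple a = getBlockedLayout({32, 256}, {1, 32}, {16, 8});
  // mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
  //   -> ([32,2,8,16], [256,128,16,1])
  LayoutTuple b = getBlockedLayout({256, 32}, {32, 1}, {8, 16});
  for (const LayoutTuple &t : {a, b}) {
    for (int64_t v : t.shape) std::cout << v << ' ';
    std::cout << "| ";
    for (int64_t v : t.strides) std::cout << v << ' ';
    std::cout << '\n';
  }
  return 0;
}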

/// Generates instructions to compute the linearize offset
// if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
// the strides of memory descriptor is always considered regardless of blocked or not
Value getLinearOffsets(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> offsets);
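As a worked illustration of the blocked offset computation (not produced by this patch): for mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1], the element at offsets (10, 20) falls in block (10/8, 20/16) = (1, 1) at intra-block position (10%8, 20%16) = (2, 4); using the blocked strides [256,128,16,1] from the example above, the linear offset is 1*256 + 1*128 + 2*16 + 4*1 = 420 elements.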


}];

let hasCustomAssemblyFormat = true;
1 change: 1 addition & 0 deletions mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
181 changes: 175 additions & 6 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/FormatVariadic.h"
@@ -61,6 +62,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
llvm_unreachable("Unknown XeGPU memory space");
}

// Get same bitwidth flat vector type of new element type.
@@ -184,6 +186,7 @@ class CreateNdDescToXeVMPattern
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");

auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -362,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {

// Add a builder that creates
// offset * elemByteSize + baseAddr
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
Value baseAddr, Value offset, int64_t elemByteSize) {
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
Location loc, Value baseAddr, Value offset,
int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -440,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// If offset is provided, we add them to the base pointer.
// Offset is in number of elements, we need to multiply by
// element byte size.
basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
basePtrI64 =
addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -503,6 +508,159 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};

// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
Contributor: why to use getType directly?

// Create the result MemRefType with the same shape, element type, and
// memory space
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);

Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
op.getSource(), zero, ValueRange());
rewriter.replaceOp(op, viewOp);
return success();
}
};

class MemDescSubviewOpPattern final
: public OpConversionPattern<xegpu::MemDescSubviewOp> {
public:
using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(
op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
Contributor: What exactly prevents it, and why should the pattern exist then? Such limitations should be clarified in the op description to not surprise users only after the xegpu code is ready for lowering.

Contributor: To me, it's still an open question how we'll help users across different generations. A separate arch-specific verifier pass might be needed.

For now, I'd be leaning toward leaving this as is just to slightly improve discoverability when people start wondering why this is missing or not working.
}
};

template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
if (offsets.empty())
return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");

auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
Value basePtrStruct = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
Value data;
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
data = op.getResult();
else
data = adaptor.getData();
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());

int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
// Element type must be multiple of 8 bits.
if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
int64_t elemByteSize = elemBitWidth / 8;

// Default memory space is SLM.
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, basePtrStruct);

// Convert base pointer (ptr) to i32
Value basePtrI32 = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);

Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
elemByteSize);

// convert base pointer (i32) to LLVM pointer type
basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

if (op.getSubgroupBlockIoAttr()) {
// if the attribute 'subgroup_block_io' is set to true, it lowers to
// xevm.blockload

Type intElemTy = rewriter.getIntegerType(elemBitWidth);
VectorType intVecTy =
VectorType::get(valOrResVecTy.getShape(), intElemTy);

if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
Value loadOp =
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
if (intVecTy != valOrResVecTy) {
loadOp =
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
}
rewriter.replaceOp(op, loadOp);
} else {
Value dataToStore = adaptor.getData();
if (valOrResVecTy != intVecTy) {
dataToStore =
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
}
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
nullptr);
rewriter.eraseOp(op);
}
return success();
}

if (valOrResVecTy.getNumElements() >= 1) {
auto chipOpt = xegpu::getChipStr(op);
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
// the lowering for chunk load only works for pvc and bmg
return rewriter.notifyMatchFailure(
op, "The lowering is specific to pvc or bmg.");
}
}

if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
// if the size of valOrResVecTy is 1, it lowers to a scalar load/store
// operation. LLVM load/store does not support vector of size 1, so we
// need to handle this case separately.
auto scalarTy = valOrResVecTy.getElementType();
LLVM::LoadOp loadOp;
if (valOrResVecTy.getNumElements() == 1)
loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
else
loadOp =
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
rewriter.replaceOp(op, loadOp);
} else {
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
rewriter.eraseOp(op);
}
return success();
}
};

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -545,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
basePtrI64 =
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
elemByteSize);
}
}
// Default memory space is global.
@@ -785,6 +943,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
// Convert MemDescType into flattened MemRefType for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
Type elemTy = type.getElementType();
int numElems = type.getNumElements();
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
});

typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -919,6 +1084,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
CreateMemDescOpPattern, MemDescSubviewOpPattern>(
typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}