Skip to content

Commit 554b95e

Browse files
committed
add attributes
1 parent 4c58d3d commit 554b95e

File tree

6 files changed

+253
-186
lines changed

6 files changed

+253
-186
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1302,6 +1302,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13021302
let arguments = (ins XeGPU_MemDesc:$mem_desc,
13031303
Variadic<Index>: $offsets,
13041304
DenseI64ArrayAttr: $const_offsets,
1305+
OptionalAttr<I32Attr>:$vec_length,
1306+
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
1307+
OptionalAttr<UnitAttr>:$subgroupBlockIO,
13051308
OptionalAttr<DistributeLayoutAttr>:$layout
13061309
);
13071310
let results = (outs XeGPU_ValueType:$res);

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 37 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -371,7 +371,8 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
371371
Value baseAddr, Value offset, int64_t elemByteSize) {
372372
Value byteSize = arith::ConstantIntOp::create(
373373
rewriter, loc, rewriter.getI64Type(), elemByteSize);
374-
offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
374+
offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
375+
offset);
375376
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
376377
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
377378
return newAddr;
@@ -513,29 +514,36 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
513514
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
514515
// 32 bits will be converted to 32 bits.
515516
class CreateMemDescOpPattern final
516-
: public OpConversionPattern<xegpu::CreateMemDescOp> {
517+
: public OpConversionPattern<xegpu::CreateMemDescOp> {
517518
public:
518519
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
519520
LogicalResult
520521
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
521-
ConversionPatternRewriter &rewriter) const override {
522-
// DEBUG: Print operation and types
523-
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
524-
TypedValue<MemRefType> src = op.getSource();
525-
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
526-
527-
// Create the result MemRefType with the same shape, element type, and memory space
528-
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
529-
530-
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
531-
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
532-
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
533-
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
534-
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
535-
ValueRange());
536-
rewriter.replaceOp(op, viewOp);
537-
LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
538-
return success();
522+
ConversionPatternRewriter &rewriter) const override {
523+
// DEBUG: Print operation and types
524+
LLVM_DEBUG(llvm::dbgs()
525+
<< "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
526+
TypedValue<MemRefType> src = op.getSource();
527+
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
528+
529+
// Create the result MemRefType with the same shape, element type, and
530+
// memory space
531+
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
532+
533+
LLVM_DEBUG(llvm::dbgs()
534+
<< "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
535+
LLVM_DEBUG(llvm::dbgs()
536+
<< "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
537+
LLVM_DEBUG(llvm::dbgs()
538+
<< "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
539+
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
540+
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
541+
Value(src), zero, ValueRange());
542+
rewriter.replaceOp(op, viewOp);
543+
LLVM_DEBUG(
544+
llvm::dbgs()
545+
<< "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
546+
return success();
539547
}
540548
};
541549

@@ -551,7 +559,6 @@ class MemDescSubviewOpPattern final
551559
}
552560
};
553561

554-
555562
template <typename OpType,
556563
typename = std::enable_if_t<llvm::is_one_of<
557564
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -577,7 +584,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
577584
data = adaptor.getData();
578585
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
579586

580-
int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
587+
int64_t elemBitWidth =
588+
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
581589
// Element type must be multiple of 8 bits.
582590
if (elemBitWidth % 8 != 0)
583591
return rewriter.notifyMatchFailure(
@@ -589,14 +597,17 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
589597
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
590598

591599
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
592-
593-
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
600+
601+
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
602+
rewriter, loc, basePtrStruct);
594603

595604
// Convert base pointer (ptr) to i64
596-
Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
605+
Value basePtrI64 = arith::IndexCastUIOp::create(
606+
rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
597607

598608
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
599-
basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
609+
basePtrI64 =
610+
addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
600611

601612
// convert base pointer (i64) to LLVM pointer type
602613
basePtrLLVM =

0 commit comments

Comments
 (0)