Commit 8e507ec

Adjust to latest XeGPU dialect update.

1 parent e240e47

3 files changed: +185 -194 lines changed


mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 131 additions & 67 deletions
@@ -446,9 +446,10 @@ class CreateDescToXeVMPattern
   matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
-    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
-      return rewriter.notifyMatchFailure(op,
-          "Expected element type bit width to be multiple of 8.");
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
     // offsets are provided as scalar i64 by type converter.
@@ -458,10 +459,8 @@ class CreateDescToXeVMPattern
     Value addr = adaptor.getSource();
     // ui32 or i32 are passed as i32 so they need to be casted to i64.
     if (addr.getType() != rewriter.getI64Type())
-      addr = arith::IndexCastUIOp::create(
-          rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr =
-        addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
+      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
     rewriter.replaceOp(op, laneAddr);
     return success();
   }
@@ -475,16 +474,16 @@ class UpdateOffsetToXeVMPattern
                  xegpu::UpdateOffsetOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
-    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
-      return rewriter.notifyMatchFailure(op,
-          "Expected element type bit width to be multiple of 8.");
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
     // scatter descriptor is provided as scalar i64 by type converter.
     // offsets are provided as scalar i64 by type converter.
-    Value newOffset =
-        addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
-                  getElemByteSize(op));
+    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+                                adaptor.getOffsets(), eBw / 8);
     rewriter.replaceOp(op, newOffset);
     return success();
   }
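
Both hunks above replace the removed `getElemByteSize(op)` call with `eBw / 8`, computed once from the descriptor's element type, and the CreateDescOp hunk switches the i32-to-i64 widening from `arith::IndexCastUIOp` to `arith::ExtUIOp`, since the adaptor supplies a plain integer rather than an `index` value. The `addOffset` helper they call is not shown in this diff; a minimal sketch of what it plausibly does, assuming it scales an element-count offset to bytes and adds it to an i64 base address:

// Hypothetical reconstruction, not the helper from this file: advance an
// i64 base address by `offsetI64` elements of `elemByteSize` bytes each.
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                       Value baseI64, Value offsetI64, int64_t elemByteSize) {
  Value scale = arith::ConstantIntOp::create(
      rewriter, loc, rewriter.getI64Type(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offsetI64, scale);
  return arith::AddIOp::create(rewriter, loc, baseI64, byteOffset);
}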
@@ -501,12 +500,35 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
-        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
-    if (tdescTy)
-      ptrTypeLLVM = LLVM::LLVMPointerType::get(
-          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64;
+    // Load result or store value type can be vector or scalar.
+    Type valOrResTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      valOrResTy = op.getResult().getType();
+    } else {
+      valOrResTy = adaptor.getValue().getType();
+    }
+    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+    bool hasScalarVal = !valOrResVecTy;
+    int64_t elemBitWidth =
+        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    }
+    int64_t elemByteSize = elemBitWidth / 8;
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy) {
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    }
+    // Base pointer can come from source (load) or dest (store).
+    // If they are memrefs, we use their memory space.
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       basePtrI64 = adaptor.getSource();
       if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
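
The hunk above now derives the element byte size from the load result or stored value, which may be a scalar or a vector after type conversion. The same selection, factored into a hypothetical standalone helper for illustration:

// Hypothetical helper (not in the diff): element bit width of a scalar or
// vector value/result type.
static int64_t getElemBitWidth(Type valOrResTy) {
  if (auto vecTy = dyn_cast<VectorType>(valOrResTy))
    return vecTy.getElementType().getIntOrFloatBitWidth();
  return valOrResTy.getIntOrFloatBitWidth();
}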
@@ -522,76 +544,79 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
         ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
       }
     }
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
     if (basePtrI64.getType() != rewriter.getI64Type()) {
-      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
-                                                basePtrI64);
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
     }
-    basePtrI64.dump();
     Value offsets = adaptor.getOffsets();
-    offsets.dump();
     Value mask = adaptor.getMask();
-    mask.dump();
     if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())){
-        // Offset needs be scalar.
+      if (dyn_cast<VectorType>(offsets.getType())) {
+        // Offset needs to be scalar. Single element vector is converted to
+        // scalar by type converter.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
       } else {
+        // If offsets are provided, we add them to the base pointer.
+        // Offsets are in number of elements, so we multiply by the
+        // element byte size.
         basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
       }
     }
-    basePtrI64.dump();
+    // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
-    basePtrLLVM.dump();
-    VectorType srcOrDstVecTy = op.getValueType();
-    VectorType srcOrDstFlatVecTy = VectorType::get(
-        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+
     Value maskForLane;
     VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
     if (maskVecTy) {
+      // Mask needs to be scalar. Single element vector is converted to scalar
+      // by type converter.
       return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
-    } else
+    } else {
       maskForLane = mask;
+    }
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
-      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy},
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                          maskForLane, true, true);
+      // If mask is true (then clause), load from memory and yield.
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      if (!hasScalarVal)
+        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+                                     valOrResVecTy.getElementType());
       Value loaded =
-          LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
+          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+      // Set cache control attribute on the load operation.
       loaded.getDefiningOp()->setAttr(
           "cache_control", xevm::LoadCacheControlAttr::get(
                                ctxt, translateLoadXeGPUCacheHint(
                                          op.getL1Hint(), op.getL3Hint())));
-      if (srcOrDstVecTy != srcOrDstFlatVecTy) {
-        loaded =
-            vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
-      }
       scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      // If mask is false, we yield a vector of zeros.
-      auto eTy = srcOrDstVecTy.getElementType();
-      loaded = arith::ConstantOp::create(
-          rewriter, loc,
-          eTy.isFloat()
-              ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0))
-              : DenseElementsAttr::get(srcOrDstVecTy,
-                                       IntegerAttr::get(eTy, 0)));
+      // If mask is false (else clause), yield zeros (scalar or splat vector).
+      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+      TypedAttr eVal;
+      if (eTy.isFloat())
+        eVal = FloatAttr::get(eTy, 0.0);
+      else
+        eVal = IntegerAttr::get(eTy, 0);
+      if (hasScalarVal)
+        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+      else
+        loaded = arith::ConstantOp::create(
+            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
       scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
       rewriter.replaceOp(op, ifOp.getResult(0));
     } else {
+      // If mask is true, perform the store.
       scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
       auto body = ifOp.getBody();
       rewriter.setInsertionPointToStart(body);
-      VectorType valTy = op.getValue().getType();
-      Value srcFlatVec = op.getValue();
-      if (valTy != srcOrDstFlatVecTy) {
-        srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
-                                                 srcOrDstFlatVecTy, srcFlatVec);
-      }
       auto storeOp =
-          LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+      // Set cache control attribute on the store operation.
       storeOp.getOperation()->setAttr(
           "cache_control", xevm::StoreCacheControlAttr::get(
                                ctxt, translateStoreXeGPUCacheHint(
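
The else branch of the load case now materializes a typed zero for either a scalar or a vector result, where the old code handled only vectors. The same zero-fill logic, extracted into a hypothetical helper (a null `vecTy` means a scalar result):

// Hypothetical refactoring of the zero-fill logic above, mirroring the
// scalar/vector split; not part of the actual change.
static Value createZeroConstant(ConversionPatternRewriter &rewriter,
                                Location loc, Type eTy, VectorType vecTy) {
  TypedAttr eVal;
  if (eTy.isFloat())
    eVal = FloatAttr::get(eTy, 0.0);
  else
    eVal = IntegerAttr::get(eTy, 0);
  if (!vecTy) // scalar result: plain arith.constant
    return arith::ConstantOp::create(rewriter, loc, eVal);
  // vector result: splat DenseElementsAttr of the element zero
  return arith::ConstantOp::create(rewriter, loc,
                                   DenseElementsAttr::get(vecTy, eVal));
}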
@@ -610,27 +635,64 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
-        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64 = adaptor.getSource();
-    Value offsets = adaptor.getOffsets();
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
     if (basePtrI64.getType() != rewriter.getI64Type()) {
-      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
-                                                basePtrI64);
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
     }
+    Value offsets = adaptor.getOffsets();
     if (offsets) {
       VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
       if (offsetsVecTy) {
         // Offset needs be scalar.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
       } else {
+        int64_t elemBitWidth{0};
+        int64_t elemByteSize;
+        // Element byte size can come from three sources:
+        if (tdescTy) {
+          // If tensor descriptor is available, we use its element type to
+          // determine element byte size.
+          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
+          // If memref is available, we use its element type to
+          // determine element byte size.
+          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
+        } else {
+          // Otherwise, we use the provided offset byte alignment.
+          elemByteSize = *op.getOffsetAlignByte();
+        }
+        if (elemBitWidth != 0) {
+          if (elemBitWidth % 8 != 0) {
+            return rewriter.notifyMatchFailure(
+                op, "Expected element type bit width to be multiple of 8.");
+          }
+          elemByteSize = elemBitWidth / 8;
+        }
         basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
       }
     }
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy) {
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    }
+    // If source is a memref, we use its memory space.
+    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+      auto addrSpace = memRefTy.getMemorySpaceAsInt();
+      if (addrSpace != 0)
+        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+    }
+    // Convert base pointer (i64) to LLVM pointer type.
     Value ptrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    // Create the prefetch op with cache control attribute.
     xevm::PrefetchOp::create(
         rewriter, loc, ptrLLVM,
         xevm::LoadCacheControlAttr::get(
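
The prefetch lowering resolves the element byte size from three sources, in order: the tensor descriptor's element type, the memref source's element type, and finally the op's offset byte alignment. A compact sketch of that resolution order as a hypothetical helper, with `FailureOr` signaling a bit width that is not a multiple of 8:

// Hypothetical helper summarizing the three-way byte-size resolution above.
static FailureOr<int64_t> resolveElemByteSize(xegpu::PrefetchOp op) {
  int64_t elemBitWidth = 0;
  if (auto tdescTy = op.getTensorDescType())
    elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType()))
    elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
  else
    return *op.getOffsetAlignByte(); // byte alignment provided directly
  if (elemBitWidth % 8 != 0)
    return failure();
  return elemBitWidth / 8;
}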
@@ -863,17 +925,17 @@ struct ConvertXeGPUToXeVMPass
     });

     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
-                                       ValueRange inputs,
-                                       Location loc) -> Value {
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {

-        Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
-            builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type,
-                                            addr).getResult();
+        Value addr =
+            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
+        return arith::IndexCastUIOp::create(builder, loc, type, addr)
+            .getResult();
       }
       return {};
     };
@@ -888,7 +950,8 @@ struct ConvertXeGPUToXeVMPass
       Value cast =
           index::CastUOp::create(builder, loc, builder.getIndexType(), input)
               .getResult();
-      return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+      return arith::IndexCastUIOp::create(builder, loc, type, cast)
+          .getResult();
     }
     return {};
   };
@@ -903,7 +966,8 @@ struct ConvertXeGPUToXeVMPass
       Value cast =
           index::CastUOp::create(builder, loc, builder.getIndexType(), input)
               .getResult();
-      return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+      return arith::IndexCastUIOp::create(builder, loc, type, cast)
+          .getResult();
     }
     return {};
   };
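
Both materialization hunks above reformat the same two-step cast: the integer input is first widened to `index` via `index::CastUOp`, and only then converted to the requested integer type with `arith::IndexCastUIOp`, which requires `index` on one side of the cast. The enclosing lambda is only partially visible in these hunks; a sketch of its likely overall shape (names assumed):

// Hypothetical reconstruction of the surrounding materialization lambda;
// only the two cast statements appear verbatim in the diff.
auto uiMaterializationCast = [](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
  if (inputs.size() != 1)
    return {};
  Value input = inputs.front();
  if (isa<IntegerType>(input.getType())) {
    // e.g. i32 -> index (unsigned) -> requested integer type
    Value cast =
        index::CastUOp::create(builder, loc, builder.getIndexType(), input)
            .getResult();
    return arith::IndexCastUIOp::create(builder, loc, type, cast)
        .getResult();
  }
  return {};
};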
