Skip to content

Commit 59f7ea9

Browse files
committed
add optional offsets to load_gather
1 parent 3578c1b commit 59f7ea9

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
655655
}
656656

657657
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
658-
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
658+
AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
659659
]> {
660660
let summary = "load a set of scattered data points from memory.";
661661

@@ -698,16 +698,21 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
698698

699699
}];
700700

701-
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
701+
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
702702
XeGPU_MaskType: $mask,
703703
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
704704
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
705705
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
706706
let results = (outs XeGPU_ValueType: $value);
707707

708708
let extraClassDeclaration = extraBaseClassDeclaration # [{
709+
710+
Type getSourceType() {
711+
return getSource().getType();
712+
}
713+
709714
xegpu::TensorDescType getTensorDescType() {
710-
return getTensorDesc().getType();
715+
return dyn_cast<xegpu::TensorDescType>(getSourceType());
711716
}
712717

713718
mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
725730

726731
}];
727732

728-
let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
729-
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
733+
let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
734+
`:` qualified(type($source)) `,` type($mask) `->` type($value)}];
730735

731736
let hasVerifier = 1;
732737
}

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
189189
let genVerifyDecl = 1;
190190
}
191191

192+
def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
192193

193194
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
194195
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
502502
SmallVector<Type> convertedTdescTypes =
503503
getUnrolledTypes(tdescTy, *targetShape);
504504
SmallVector<Value> convertedTdescs = pack(
505-
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
505+
op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
506506

507507
SmallVector<Type> convertedMaskTypes;
508508
SmallVector<Value> convertedMasks;

0 commit comments

Comments
 (0)