Skip to content

Commit 80b4462

Browse files
committed
Add chunk_size and use XeGPU_OffsetType
1 parent abc84c7 commit 80b4462

File tree

4 files changed

+105
-41
lines changed

4 files changed

+105
-41
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
1616
include "mlir/Interfaces/ShapedOpInterfaces.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
1818
include "mlir/Interfaces/ViewLikeInterface.td"
19+
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1920

2021
// Base class for dialect operations. This operation inherits from the base
2122
// `Op` class in OpBase.td, and provides:
@@ -638,18 +639,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
638639

639640
}];
640641

641-
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
642+
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
643+
Optional<XeGPU_OffsetType>: $offsets,
642644
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
643645
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
644646
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
645647

646648
let extraClassDeclaration = extraBaseClassDeclaration # [{
649+
Type getSourceType() {
650+
return getSource().getType();
651+
}
652+
653+
Value getTensorDesc() {
654+
return getSource();
655+
}
656+
647657
xegpu::TensorDescType getTensorDescType() {
648-
return getTensorDesc().getType();
658+
return dyn_cast<xegpu::TensorDescType>(getSourceType());
649659
}
650660
}];
651661

652-
let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
662+
let assemblyFormat = [{
663+
$source
664+
(`,` $offsets^)?
665+
prop-dict
666+
attr-dict `:` type($source) (`,` type($offsets)^)?
667+
}];
668+
669+
let builders = [
670+
OpBuilder<(ins "Value": $source,
671+
"xegpu::CachePolicyAttr": $l1_hint,
672+
"xegpu::CachePolicyAttr": $l2_hint,
673+
"xegpu::CachePolicyAttr": $l3_hint)>
674+
];
653675

654676
let hasVerifier = 1;
655677
}
@@ -702,6 +724,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
702724
Variadic<Index>: $offsets,
703725
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
704726
XeGPU_MaskType: $mask,
727+
OptionalAttr<I64Attr>: $chunk_size,
705728
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
706729
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
707730
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -713,6 +736,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
713736
return getSource().getType();
714737
}
715738

739+
Value getTensorDesc() {
740+
return getSource();
741+
}
742+
716743
xegpu::TensorDescType getTensorDescType() {
717744
return dyn_cast<xegpu::TensorDescType>(getSourceType());
718745
}
@@ -733,25 +760,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
733760
}];
734761

735762
let assemblyFormat = [{
736-
$source `,`
737-
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
763+
$source ``
764+
custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
738765
$mask prop-dict
739766
attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
740767
}];
741768

742-
// let builders = [
743-
// OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
744-
// "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
745-
// "xegpu::CachePolicyAttr": $l1_hint,
746-
// "xegpu::CachePolicyAttr": $l2_hint,
747-
// "xegpu::CachePolicyAttr": $l3_hint)>
748-
// ];
769+
let builders = [
770+
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
771+
"xegpu::CachePolicyAttr": $l1_hint,
772+
"xegpu::CachePolicyAttr": $l2_hint,
773+
"xegpu::CachePolicyAttr": $l3_hint)>
774+
];
749775

750776
let hasVerifier = 1;
751777
}
752778

753779
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
754-
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
780+
AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
755781
]> {
756782
let summary = "store data to scattered memory locations.";
757783
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
@@ -791,15 +817,26 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
791817

792818
let arguments = (ins
793819
XeGPU_ValueType: $value,
794-
XeGPU_TensorDesc: $TensorDesc,
820+
XeGPU_TensorDesc_or_MemRef: $dest,
821+
Variadic<Index>: $offsets,
822+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
795823
XeGPU_MaskType: $mask,
824+
OptionalAttr<I64Attr>: $chunk_size,
796825
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
797826
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
798827
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
799828

800829
let extraClassDeclaration = extraBaseClassDeclaration # [{
830+
Type getDestType() {
831+
return getDest().getType();
832+
}
833+
834+
Value getTensorDesc() {
835+
return getDest();
836+
}
837+
801838
xegpu::TensorDescType getTensorDescType() {
802-
return getTensorDesc().getType();
839+
return dyn_cast<xegpu::TensorDescType>(getDestType());
803840
}
804841

805842
VectorType getValueType() {
@@ -811,8 +848,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
811848
}
812849
}];
813850

814-
let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
815-
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
851+
let assemblyFormat = [{
852+
$value `,`
853+
$dest ``
854+
custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
855+
$mask
856+
prop-dict
857+
attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
858+
}];
859+
860+
let builders = [
861+
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
862+
"xegpu::CachePolicyAttr": $l1_hint,
863+
"xegpu::CachePolicyAttr": $l2_hint,
864+
"xegpu::CachePolicyAttr": $l3_hint)>
865+
];
816866

817867
let hasVerifier = 1;
818868
}

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ LogicalResult CreateDescOp::verify() {
644644
//===----------------------------------------------------------------------===//
645645
LogicalResult PrefetchOp::verify() {
646646
auto tdescTy = getTensorDescType();
647-
if (!tdescTy.isScattered())
647+
if (tdescTy && !tdescTy.isScattered())
648648
return emitOpError("Expects a scattered TensorDesc.\n");
649649

650650
if (!isReadHintOrNone(getL1HintAttr()))
@@ -659,6 +659,13 @@ LogicalResult PrefetchOp::verify() {
659659
return success();
660660
}
661661

662+
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
663+
xegpu::CachePolicyAttr l1_hint,
664+
xegpu::CachePolicyAttr l2_hint,
665+
xegpu::CachePolicyAttr l3_hint) {
666+
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
667+
}
668+
662669
//===----------------------------------------------------------------------===//
663670
// XeGPU_LoadGatherOp
664671
//===----------------------------------------------------------------------===//
@@ -680,6 +687,15 @@ LogicalResult LoadGatherOp::verify() {
680687
[&]() { return emitOpError(); });
681688
}
682689

690+
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
691+
Type valueType, Value source, Value mask,
692+
xegpu::CachePolicyAttr l1_hint,
693+
xegpu::CachePolicyAttr l2_hint,
694+
xegpu::CachePolicyAttr l3_hint) {
695+
build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
696+
mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
697+
}
698+
683699
//===----------------------------------------------------------------------===//
684700
// XeGPU_StoreScatterOp
685701
//===----------------------------------------------------------------------===//
@@ -701,6 +717,15 @@ LogicalResult StoreScatterOp::verify() {
701717
[&]() { return emitOpError(); });
702718
}
703719

720+
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
721+
Value value, Value dest, Value mask,
722+
xegpu::CachePolicyAttr l1_hint,
723+
xegpu::CachePolicyAttr l2_hint,
724+
xegpu::CachePolicyAttr l3_hint) {
725+
build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
726+
IntegerAttr(), l1_hint, l2_hint, l3_hint);
727+
}
728+
704729
//===----------------------------------------------------------------------===//
705730
// XeGPU_UpdateOffsetOp
706731
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
502502
SmallVector<Type> convertedTdescTypes =
503503
getUnrolledTypes(tdescTy, *targetShape);
504504
SmallVector<Value> convertedTdescs = pack(
505-
op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
505+
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
506506

507507
SmallVector<Type> convertedMaskTypes;
508508
SmallVector<Value> convertedMasks;

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
130130
gpu.return
131131
}
132132

133-
// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
134-
gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
135-
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
136-
%1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
137-
// CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
138-
xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
139-
gpu.return
140-
}
141-
142133
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
143134
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
144135
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,19 +330,8 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
339330
gpu.return
340331
}
341332

342-
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
343-
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
344-
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
345-
%1 = arith.constant dense<1.0>: vector<32xf16>
346-
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
347-
%2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
348-
// CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
349-
xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
350-
gpu.return
351-
}
352-
353-
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
354-
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
333+
// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
334+
gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
355335
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
356336
%1 = arith.constant dense<1.0>: vector<32xf16>
357337
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -658,6 +638,15 @@ gpu.func @prefetch(%src: ui64) {
658638
}
659639

660640

641+
// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
642+
gpu.func @prefetch_offset(%src: ui64) {
643+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
644+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
645+
// CHECK: xegpu.prefetch %[[arg0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
646+
xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
647+
gpu.return
648+
}
649+
661650
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
662651
gpu.func @create_update_tdesc(%src: ui64) {
663652
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

0 commit comments

Comments
 (0)