@@ -16,6 +16,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
1616include "mlir/Interfaces/ShapedOpInterfaces.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
1818include "mlir/Interfaces/ViewLikeInterface.td"
19+ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1920
2021// Base class for dialect operations. This operation inherits from the base
2122// `Op` class in OpBase.td, and provides:
@@ -638,18 +639,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
638639
639640 }];
640641
641- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
642+ let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
643+ Optional<XeGPU_OffsetType>: $offsets,
642644 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
643645 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
644646 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
645647
646648 let extraClassDeclaration = extraBaseClassDeclaration # [{
649+ Type getSourceType() {
650+ return getSource().getType();
651+ }
652+
653+ Value getTensorDesc() {
654+ return getSource();
655+ }
656+
647657 xegpu::TensorDescType getTensorDescType() {
648- return getTensorDesc().getType( );
658+ return dyn_cast<xegpu::TensorDescType>(getSourceType() );
649659 }
650660 }];
651661
652- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
662+ let assemblyFormat = [{
663+ $source
664+ (`,` $offsets^)?
665+ prop-dict
666+ attr-dict `:` type($source) (`,` type($offsets)^)?
667+ }];
668+
669+ let builders = [
670+ OpBuilder<(ins "Value": $source,
671+ "xegpu::CachePolicyAttr": $l1_hint,
672+ "xegpu::CachePolicyAttr": $l2_hint,
673+ "xegpu::CachePolicyAttr": $l3_hint)>
674+ ];
653675
654676 let hasVerifier = 1;
655677}
@@ -702,6 +724,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
702724 Variadic<Index>: $offsets,
703725 OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
704726 XeGPU_MaskType: $mask,
727+ OptionalAttr<I64Attr>: $chunk_size,
705728 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
706729 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
707730 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -713,6 +736,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
713736 return getSource().getType();
714737 }
715738
739+ Value getTensorDesc() {
740+ return getSource();
741+ }
742+
716743 xegpu::TensorDescType getTensorDescType() {
717744 return dyn_cast<xegpu::TensorDescType>(getSourceType());
718745 }
@@ -733,25 +760,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
733760 }];
734761
735762 let assemblyFormat = [{
736- $source `, `
737- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
763+ $source ``
764+ custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
738765 $mask prop-dict
739766 attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
740767 }];
741768
742- // let builders = [
743- // OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
744- // "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
745- // "xegpu::CachePolicyAttr": $l1_hint,
746- // "xegpu::CachePolicyAttr": $l2_hint,
747- // "xegpu::CachePolicyAttr": $l3_hint)>
748- // ];
769+ let builders = [
770+ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
771+ "xegpu::CachePolicyAttr": $l1_hint,
772+ "xegpu::CachePolicyAttr": $l2_hint,
773+ "xegpu::CachePolicyAttr": $l3_hint)>
774+ ];
749775
750776 let hasVerifier = 1;
751777}
752778
753779def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
754- AllElementTypesMatch<["value", "TensorDesc "]>, MemoryEffects<[MemWrite]>
780+ AllElementTypesMatch<["value", "dest "]>, MemoryEffects<[MemWrite]>
755781 ]> {
756782 let summary = "store data to scattered memory locations.";
757783 let description = [{ It (aka. store) stores data to scattered memory locations. The value is
@@ -791,15 +817,26 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
791817
792818 let arguments = (ins
793819 XeGPU_ValueType: $value,
794- XeGPU_TensorDesc: $TensorDesc,
820+ XeGPU_TensorDesc_or_MemRef: $dest,
821+ Variadic<Index>: $offsets,
822+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
795823 XeGPU_MaskType: $mask,
824+ OptionalAttr<I64Attr>: $chunk_size,
796825 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
797826 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
798827 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
799828
800829 let extraClassDeclaration = extraBaseClassDeclaration # [{
830+ Type getDestType() {
831+ return getDest().getType();
832+ }
833+
834+ Value getTensorDesc() {
835+ return getDest();
836+ }
837+
801838 xegpu::TensorDescType getTensorDescType() {
802- return getTensorDesc().getType( );
839+ return dyn_cast<xegpu::TensorDescType>(getDestType() );
803840 }
804841
805842 VectorType getValueType() {
@@ -811,8 +848,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
811848 }
812849 }];
813850
814- let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
815- `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
851+ let assemblyFormat = [{
852+ $value `,`
853+ $dest ``
854+ custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
855+ $mask
856+ prop-dict
857+ attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
858+ }];
859+
860+ let builders = [
861+ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
862+ "xegpu::CachePolicyAttr": $l1_hint,
863+ "xegpu::CachePolicyAttr": $l2_hint,
864+ "xegpu::CachePolicyAttr": $l3_hint)>
865+ ];
816866
817867 let hasVerifier = 1;
818868}
0 commit comments