@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
628628 As compared to prefetch_nd, which works on non-scattered TensorDesc,
629629 it works on scattered TensorDesc instead.
630630
631- Example:
631+ Example 1:
632632 ```mlir
633633 xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
634634 l2_hint = #xegpu.cache_hint<cached>,
635635 l3_hint = #xegpu.cache_hint<cached>}
636636 : !xegpu.tensor_desc<16xf16>
637637 ```
638+
639+ Example 2:
640+ A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
641+ It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
642+ The source operand can also be a raw pointer (uint64_t).
643+ Please refer to create_tdesc for the restrictions on the memref.
644+ ```mlir
645+ %a = memref.alloc() : memref<1024xf32>
646+ %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
647+ xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
648+ l2_hint = #xegpu.cache_hint<cached>,
649+ l3_hint = #xegpu.cache_hint<cached>}
650+ : memref<1024xf32>, vector<4xindex>
651+ ```
638652
639653 }];
640654
641- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
655+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
656+ Optional<XeGPU_OffsetType>: $offsets,
642657 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
643658 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
644659 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
645660
646661 let extraClassDeclaration = extraBaseClassDeclaration # [{
662+ Type getSourceType() {
663+ return getSource().getType();
664+ }
665+
666+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
667+ if (auto tdescType = getTensorDescType()) {
668+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
669+ }
670+ return TypedValue<xegpu::TensorDescType>();
671+ }
672+
647673 xegpu::TensorDescType getTensorDescType() {
648- return getTensorDesc().getType( );
674+ return dyn_cast<xegpu::TensorDescType>(getSourceType() );
649675 }
650676 }];
651677
652- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
678+ let assemblyFormat = [{
679+ $source
680+ (`[` $offsets^ `]`)?
681+ prop-dict
682+ attr-dict `:` type(operands)
683+ }];
684+
685+ let builders = [
686+ OpBuilder<(ins "Value": $source,
687+ "xegpu::CachePolicyAttr": $l1_hint,
688+ "xegpu::CachePolicyAttr": $l2_hint,
689+ "xegpu::CachePolicyAttr": $l3_hint)>
690+ ];
653691
654692 let hasVerifier = 1;
655693}
656694
657- def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
658- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
659- ]> {
695+ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
660696 let summary = "load a set of scattered data points from memory.";
661697
662698 let description = [{ It (aka. load) loads data per work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
687723 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
688724 vector<16xi1> -> vector<16x8xf32>
689725 ```
726+
690727 Example 3 (SIMT mode):
691728 ```mlir
692729 %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
695732 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
696733 vector<16xi1> -> vector<8xf32>
697734 ```
735+
736+ Example 4:
737+ A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
738+ It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
739+ The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc
740+ for the restrictions on the memref.
741+ ```mlir
742+ %a = memref.alloc() : memref<1024xf32>
743+ %offsets = vector.step : vector<16xindex>
744+ %mask = vector.constant_mask [16]: vector<16xi1>
745+ %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
746+ l2_hint = #xegpu.cache_hint<cached>,
747+ l3_hint = #xegpu.cache_hint<cached>}
748+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
749+ ```
698750
699751 }];
700752
701- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
753+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
754+ Optional<XeGPU_OffsetType>: $offsets,
702755 XeGPU_MaskType: $mask,
756+ OptionalAttr<I64Attr>: $chunk_size,
703757 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
704758 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
705759 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
706760 let results = (outs XeGPU_ValueType: $value);
707761
708762 let extraClassDeclaration = extraBaseClassDeclaration # [{
763+
764+ Type getSourceType() {
765+ return getSource().getType();
766+ }
767+
768+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
769+ if (auto tdescType = getTensorDescType()) {
770+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
771+ }
772+ return TypedValue<xegpu::TensorDescType>();
773+ }
774+
709775 xegpu::TensorDescType getTensorDescType() {
710- return getTensorDesc().getType( );
776+ return dyn_cast<xegpu::TensorDescType>(getSourceType() );
711777 }
712778
713779 mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
725791
726792 }];
727793
728- let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
729- `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
794+ let assemblyFormat = [{
795+ $source
796+ (`[` $offsets^ `]`)? `,`
797+ $mask prop-dict
798+ attr-dict `:` type(operands) `->` type($value)
799+ }];
800+
801+ let builders = [
802+ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803+ "xegpu::CachePolicyAttr": $l1_hint,
804+ "xegpu::CachePolicyAttr": $l2_hint,
805+ "xegpu::CachePolicyAttr": $l3_hint)>
806+ ];
730807
731808 let hasVerifier = 1;
732809}
733810
734- def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
735- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
736- ]> {
811+ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
737812 let summary = "store data to scattered memory locations.";
738813 let description = [{ It (aka. store) stores data to scattered memory locations. The value is
739814 typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
768843 l3_hint = #xegpu.cache_hint<write_through>}>
769844 : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>, vector<16xi1>
770845 ```
846+
847+ Example 4:
848+ A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
849+ It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
850+ The dest operand can also be a raw pointer (uint64_t).
851+ Please refer to create_tdesc for the restrictions on the memref.
852+ ```mlir
853+ %a = memref.alloc() : memref<1024xf32>
854+ %val = arith.constant dense<0.0> : vector<16xf32>
855+ %offsets = vector.step : vector<16xindex>
856+ %mask = vector.constant_mask [16]: vector<16xi1>
857+ xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
858+ l2_hint = #xegpu.cache_hint<cached>,
859+ l3_hint = #xegpu.cache_hint<cached>}
860+ : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
861+ ```
862+
771863 }];
772864
773865 let arguments = (ins
774866 XeGPU_ValueType: $value,
775- XeGPU_TensorDesc: $TensorDesc,
867+ XeGPU_GatherScatterSourceType: $dest,
868+ Optional<XeGPU_OffsetType>: $offsets,
776869 XeGPU_MaskType: $mask,
870+ OptionalAttr<I64Attr>: $chunk_size,
777871 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
778872 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
779873 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
780874
781875 let extraClassDeclaration = extraBaseClassDeclaration # [{
876+ Type getDestType() {
877+ return getDest().getType();
878+ }
879+
880+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
881+ if (auto tdescType = getTensorDescType()) {
882+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
883+ }
884+ return TypedValue<xegpu::TensorDescType>();
885+ }
886+
782887 xegpu::TensorDescType getTensorDescType() {
783- return getTensorDesc().getType( );
888+ return dyn_cast<xegpu::TensorDescType>(getDestType() );
784889 }
785890
786891 VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
792897 }
793898 }];
794899
795- let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
796- `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
900+ let assemblyFormat = [{
901+ $value `,`
902+ $dest
903+ (`[` $offsets^ `]`)? `,`
904+ $mask
905+ prop-dict
906+ attr-dict `:` type(operands)
907+ }];
908+
909+ let builders = [
910+ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911+ "xegpu::CachePolicyAttr": $l1_hint,
912+ "xegpu::CachePolicyAttr": $l2_hint,
913+ "xegpu::CachePolicyAttr": $l3_hint)>
914+ ];
797915
798916 let hasVerifier = 1;
799917}
0 commit comments