@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146  ];
147147}
148148
149+ //===----------------------------------------------------------------------===//
150+ // ROCDL vector types definitions
151+ //===----------------------------------------------------------------------===//
152+ 
153+ class ROCDL_ConcreteVector<Type elem, int length> :
154+   FixedVectorOfLengthAndType<[length], [elem]>,
155+   BuildableType<
156+     "::mlir::VectorType::get({" # length # "} ,"
157+       # elem.builderCall # ")">;
158+ 
159+ def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
160+ def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+ def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+ def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+ def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+ def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+ def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
166+ def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+ def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
168+ def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+ def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+ def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+ def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+ def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+ def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+ def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+ def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+ def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+ 
149178//===----------------------------------------------------------------------===//
150179// Wave-level primitives
151180//===----------------------------------------------------------------------===//
@@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp :
663692  }];
664693}
665694
695+ //===---------------------------------------------------------------------===//
696+ // Tensor load/store intrinsics (available in GFX1250)
697+ //===---------------------------------------------------------------------===//
698+ 
699+ // Base class for tensor load/store operations with 4 descriptor groups.
700+ class ROCDL_TensorLDSIntrOp<string mnemonic> :
701+   ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
702+   dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
703+                   ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
704+                   I32Attr:$cachePolicy);
705+   let arguments = !con(args, baseArgs);
706+   let summary = "Base class for ROCDL tensor load/store to/from LDS.";
707+   let description = [{
708+     Moves tiles of tensor data between global memory and LDS. The tile is
709+     described by the $dgroup descriptors. 4 $dgroup descriptors allows for
710+     movement of up to 5D tensors. $cachePolicy describes the memory scope and an
711+     indicator of expected data re-use.
712+ 
713+     This op is for gfx1250+ architectures.
714+   }];
715+   let assemblyFormat = [{
716+     attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
717+   }];
718+   let extraClassDefinition = [{
719+     SmallVector<Value> $cppClass::getAccessedOperands() {
720+       return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
721+     }
722+   }];
723+ }
724+ 
725+ // Base class for tensor load/store operations with 2 descriptor groups
726+ // (D2 variant).
727+ class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
728+   ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
729+   dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
730+                   I32Attr:$cachePolicy);
731+   let arguments = !con(args, baseArgs);
732+   let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
733+   let description = [{
734+     Moves tiles of tensor data between global memory and LDS. The tile is
735+     described by the $dgroup descriptors. 2 $dgroup descriptors allows for
736+     movement of up to 2D tensors. $cachePolicy describes the memory scope and an
737+     indicator of expected data re-use.
738+ 
739+     This op is for gfx1250+ architectures.
740+   }];
741+   let assemblyFormat = [{
742+     attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
743+   }];
744+   let extraClassDefinition = [{
745+     SmallVector<Value> $cppClass::getAccessedOperands() {
746+       return {getDgroup0(), getDgroup1()};
747+     }
748+   }];
749+ }
750+ 
751+ // Tensor load and store operations
752+ def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
753+ def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
754+ def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
755+ def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
756+ 
666757//===---------------------------------------------------------------------===//
667758// Operations on raw buffer resources (stride of 0, bounds checks either off or in
668759// raw buffer mode).
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9321023  }];
9331024}
9341025
935- class ROCDL_ConcreteVector<Type elem, int length> :
936-   FixedVectorOfLengthAndType<[length], [elem]>,
937-   BuildableType<
938-     "::mlir::VectorType::get({" # length # "} ,"
939-       # elem.builderCall # ")">;
940- 
941- def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
942- def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
943- def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
944- def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
945- def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
946- def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
947- def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
948- def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
949- def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
950- def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
951- def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
952- def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
953- def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
954- def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
955- def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
956- def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
957- def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
958- 
9591026//===---------------------------------------------------------------------===//
9601027// 16-bit float intrinsics
9611028//===---------------------------------------------------------------------===//
0 commit comments