@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146 ];
147147}
148148
149+ //===----------------------------------------------------------------------===//
150+ // ROCDL vector type definitions
151+ //===----------------------------------------------------------------------===//
152+
153+ class ROCDL_ConcreteVector<Type elem, int length> :  // Vector type with a fixed `length` and element type `elem`.
154+ FixedVectorOfLengthAndType<[length], [elem]>,  // Constraint side, parameterized by the single allowed length and element type.
155+ BuildableType<  // Builder side: supplies the C++ expression used to construct the type.
156+ "::mlir::VectorType::get({" # length # "} ,"  // Concatenates to `::mlir::VectorType::get({<length>} ,<elem.builderCall>)`.
157+ # elem.builderCall # ")">;
158+
159+ def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;  // Naming scheme: ROCDL_V<length><elem>Type.
160+ def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+ def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+ def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+ def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+ def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+ def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;  // Used for $dgroup0/$dgroup2/$dgroup3 of the tensor LDS ops.
166+ def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+ def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;  // Used for $dgroup1 of the tensor LDS ops.
168+ def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+ def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+ def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+ def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+ def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+ def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+ def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+ def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+ def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+
149178//===----------------------------------------------------------------------===//
150179// Wave-level primitives
151180//===----------------------------------------------------------------------===//
@@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp :
663692 }];
664693}
665694
695+ //===---------------------------------------------------------------------===//
696+ // Tensor load/store intrinsics (available in GFX1250)
697+ //===---------------------------------------------------------------------===//
698+
699+ // Base class for tensor load/store operations with 4 descriptor groups ($dgroup0..$dgroup3) plus an I32 $cachePolicy attribute.
700+ class ROCDL_TensorLDSIntrOp<string mnemonic> :
701+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {  // NOTE(review): [4] presumably marks argument #4 (0-based, i.e. $cachePolicy) as the immediate named "cachePolicy" -- confirm against ROCDL_IntrOp.
702+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
703+ ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
704+ I32Attr:$cachePolicy);
705+ let arguments = !con(args, baseArgs);  // This class's arguments first, then `baseArgs` (defined outside this block).
706+ let summary = "ROCDL tensor load/store to/from LDS";  // Inherited verbatim by every concrete op, so it must describe the op, not the base class.
707+ let description = [{
708+ Moves tiles of tensor data between global memory and LDS. The tile is
709+ described by the $dgroup descriptors. 4 $dgroup descriptors allow for
710+ movement of up to 5D tensors. $cachePolicy describes the memory scope and an
711+ indicator of expected data re-use.
712+
713+ This op is for gfx1250+ architectures.
714+ }];
715+ let assemblyFormat = [{
716+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
717+ }];  // Only the first two operand types are spelled out in the textual form.
718+ let extraClassDefinition = [{
719+ SmallVector<Value> $cppClass::getAccessedOperands() {
720+ return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
721+ }
722+ }];  // getAccessedOperands() returns all four descriptor-group operands.
723+ }
724+
725+ // Base class for tensor load/store operations with 2 descriptor groups
726+ // (D2 variant): operands $dgroup0 and $dgroup1 plus an I32 $cachePolicy attribute.
727+ class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
728+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {  // NOTE(review): [2] presumably marks argument #2 (0-based, i.e. $cachePolicy) as the immediate named "cachePolicy" -- confirm against ROCDL_IntrOp.
729+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
730+ I32Attr:$cachePolicy);
731+ let arguments = !con(args, baseArgs);  // This class's arguments first, then `baseArgs` (defined outside this block).
732+ let summary = "ROCDL tensor load/store to/from LDS (D2 variant)";  // Inherited verbatim by every concrete op, so it must describe the op, not the base class.
733+ let description = [{
734+ Moves tiles of tensor data between global memory and LDS. The tile is
735+ described by the $dgroup descriptors. 2 $dgroup descriptors allow for
736+ movement of up to 2D tensors. $cachePolicy describes the memory scope and an
737+ indicator of expected data re-use.
738+
739+ This op is for gfx1250+ architectures.
740+ }];
741+ let assemblyFormat = [{
742+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
743+ }];  // Both operand types are spelled out in the textual form.
744+ let extraClassDefinition = [{
745+ SmallVector<Value> $cppClass::getAccessedOperands() {
746+ return {getDgroup0(), getDgroup1()};
747+ }
748+ }];  // getAccessedOperands() returns both descriptor-group operands.
749+ }
750+
751+ // Concrete tensor load/store ops: the first two use the 4-descriptor-group base class, the `.d2` pair uses the 2-descriptor-group base class.
752+ def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
753+ def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
754+ def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
755+ def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
756+
666757//===---------------------------------------------------------------------===//
667758// Operations on raw buffer resources (stride of 0, bounds checks either off or in
668759// raw buffer mode).
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9321023 }];
9331024}
9341025
935- class ROCDL_ConcreteVector<Type elem, int length> :
936- FixedVectorOfLengthAndType<[length], [elem]>,
937- BuildableType<
938- "::mlir::VectorType::get({" # length # "} ,"
939- # elem.builderCall # ")">;
940-
941- def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
942- def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
943- def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
944- def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
945- def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
946- def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
947- def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
948- def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
949- def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
950- def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
951- def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
952- def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
953- def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
954- def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
955- def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
956- def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
957- def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
958-
9591026//===---------------------------------------------------------------------===//
9601027// 16-bit float intrinsics
9611028//===---------------------------------------------------------------------===//
0 commit comments