@@ -956,9 +956,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
956956//
957957// Make Tensor Descriptor Op
958958//
959- def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
960- [Pure,
961- SameVariadicOperandSize]> {
959+ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
960+ Pure,
961+ SameVariadicOperandSize,
962+ ]> {
962963 let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
963964
964965 let description = [{
@@ -969,23 +970,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
969970 let arguments = (ins
970971 TT_Ptr:$base,
971972 Variadic<I32>:$shape,
972- Variadic<I64>:$strides,
973- DenseI32ArrayAttr:$tensorShape
973+ Variadic<I64>:$strides
974974 );
975975
976- // TODO(peterbell10): define a custom IR type to represent descriptors
977- let results = (outs TT_Ptr:$result);
976+ let results = (outs TT_TensorDescType:$result);
978977
979978 let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
980979
981980 let builders = [
982- OpBuilder<(ins
983- "Value":$base,
984- "ValueRange":$shape,
985- "ValueRange":$strides,
986- "ArrayRef<int32_t>":$tensorShape
987- )>
981+ OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
988982 ];
983+
984+ let extraClassDeclaration = [{
985+ ArrayRef<int64_t> getTensorShape() {
986+ return getType().getBlockType().getShape();
987+ }
988+ }];
989+ }
990+
991+ def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
992+ let summary = "Reinterpret a pointer as a tensor descriptor";
993+
994+ let description = [{
995+ This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
996+ Ideally, we can remove this once the APIs are fully fleshed out.
997+ }];
998+
999+ let arguments = (ins TT_Ptr:$rawDesc);
1000+ let results = (outs TT_TensorDescType:$result);
1001+
1002+ let assemblyFormat = [{
1003+ $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
1004+ }];
9891005}
9901006
9911007// The following ops, including `call`, `func`, and `return` are copied and modified from
@@ -1195,20 +1211,19 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
11951211}
11961212
11971213
1198- def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
1199- MemoryEffects<[MemRead<GlobalMemory>]>]> {
1214+ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
12001215 let summary = "Load from descriptor";
12011216 let description = [{
12021217 This operation will be lowered to Nvidia TMA load operation on targets supporting it.
1203- `desc_ptr ` is a pointer to the TMA descriptor allocated in global memory .
1218+ `desc ` is a tensor descriptor object .
12041219 The destination tensor type and shape must match the descriptor otherwise the result is undefined.
12051220
12061221 This is an escape hatch and is only there for testing/experimenting.
12071222 This op will be removed in the future.
12081223 }];
12091224 let arguments = (
12101225 ins
1211- TT_PtrType:$desc_ptr ,
1226+ TT_TensorDescType:$desc ,
12121227 Variadic<I32>:$indices,
12131228 DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
12141229 DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1217,36 +1232,37 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
12171232 let results = (outs TT_Tensor:$result);
12181233
12191234 let assemblyFormat = [{
1220- $desc_ptr `[` $indices `]`
1235+ $desc `[` $indices `]`
12211236 oilist(
12221237 `cacheModifier` `=` $cache |
12231238 `evictionPolicy` `=` $evict
12241239 )
1225- attr-dict `:` qualified(type($desc_ptr )) `->` type($result)
1240+ attr-dict `:` qualified(type($desc )) `->` type($result)
12261241 }];
12271242}
12281243
12291244def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
1230- MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
1245+ MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1246+ ]> {
12311247 let summary = "store value based on descriptor";
12321248 let description = [{
12331249 This operation will be lowered to Nvidia TMA store operation on targets supporting it.
1234- `desc_ptr ` is a pointer to the TMA descriptor allocated in global memory .
1250+ `desc ` is a tensor descriptor object .
12351251 The shape and types of `src` must match the descriptor otherwise the result is undefined.
12361252
12371253 This is an escape hatch and is only there for testing/experimenting.
12381254 This op will be removed in the future.
12391255 }];
12401256 let arguments = (
12411257 ins
1242- TT_PtrType:$desc_ptr ,
1258+ TT_TensorDescType:$desc ,
12431259 TT_Tensor:$src,
12441260 Variadic<I32>:$indices
12451261 );
12461262
12471263 let assemblyFormat = [{
1248- $desc_ptr `[` $indices `]` `,` $src
1249- attr-dict `:` qualified(type($desc_ptr )) `,` type($src)
1264+ $desc `[` $indices `]` `,` $src
1265+ attr-dict `:` qualified(type($desc )) `,` type($src)
12501266 }];
12511267}
12521268
0 commit comments