33
44include "mlir/IR/OpBase.td"
55include "mlir/Interfaces/SideEffectInterfaces.td"
6- include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7- include "mlir/Dialect/Ptr/IR/PtrDialect.td"
6+ // include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7+ // include "mlir/Dialect/Ptr/IR/PtrDialect.td"
88include "mlir/IR/AttrTypeBase.td"
99include "mlir/IR/BuiltinAttributeInterfaces.td"
1010include "mlir/IR/BuiltinTypeInterfaces.td"
11+ include "triton/Dialect/Triton/IR/TritonDialect.td"
12+ include "triton/Dialect/Triton/IR/TritonTypes.td"
13+
1114
1215def TPtr_Dialect : Dialect {
1316 let name = "tptr";
@@ -24,8 +27,8 @@ def TPtr_Dialect : Dialect {
2427 void registerTypes();
2528 }];
2629
27- let dependentDialects = [
28- "mlir::ptr::PtrDialect "
30+ let dependentDialects = [
31+ "mlir::triton::TritonDialect "
2932 ];
3033
3134 let useDefaultAttributePrinterParser = 1;
@@ -47,10 +50,10 @@ class TPtr_Attr<string name, string _mnemonic,
4750// Memory space attr is required for building Ptr ops
4851// This acts as default memory space since there is
4952// no such default implemented upstream
50- def DefaultMemorySpaceAttr
51- : TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
52- [DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
53- }
53+ // def DefaultMemorySpaceAttr
54+ // : TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
55+ // [DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
56+ // }
5457
5558//
5659// Op Base
@@ -72,7 +75,7 @@ def TPTR_IntToPtrOp : TPTR_Op<"inttoptr", [
7275 ```
7376 }];
7477 let arguments = (ins AnySignlessIntegerOrIndex:$arg);
75- let results = (outs Ptr_PtrType :$res);
78+ let results = (outs TT_PtrLike :$res);
7679 let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
7780}
7881
@@ -88,45 +91,61 @@ def TPTR_PtrToIntOp : TPTR_Op<"ptrtoint", [
8891 %int = ptr.ptrtoint %ptr : !ptr.ptr<1 : i32> to i32
8992 ```
9093 }];
91- let arguments = (ins Ptr_PtrType :$arg);
94+ let arguments = (ins TT_PtrLike :$arg);
9295 let results = (outs AnySignlessIntegerOrIndex:$res);
9396 let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
9497}
9598
96- def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
97- let summary = "Creates a type offset constant.";
98- let description = [{
99- The `addr.type_offset` operation produces an int or index-typed SSA value
100- equal to a target-specific constant representing the offset of a single
101- element of the given type. The default return type is `index`.
102- Example:
103-
104- ```mlir
105- %0 = addr.type_offset f32
106- %1 = addr.type_offset memref<12 x f64> : i32
107- ```
108- }];
109-
110- let arguments = (ins TypeAttr:$baseType);
111- let results = (outs AnySignlessIntegerOrIndex:$result);
112- let builders = [
113- OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy)>
114- ];
115- let assemblyFormat = [{
116- attr-dict $baseType custom<IntType>(type($result))
117- }];
118- let hasFolder = 1;
119- }
99+ // def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
100+ // let summary = "Creates a type offset constant.";
101+ // let description = [{
102+ // The `addr.type_offset` operation produces an int or index-typed SSA value
103+ // equal to a target-specific constant representing the offset of a single
104+ // element of the given type. The default return type is `index`.
105+ // Example:
106+
107+ // ```mlir
108+ // %0 = addr.type_offset f32
109+ // %1 = addr.type_offset memref<12 x f64> : i32
110+ // ```
111+ // }];
112+
113+ // let arguments = (ins TypeAttr:$baseType);
114+ // let results = (outs AnySignlessIntegerOrIndex:$result);
115+ // let builders = [
116+ // OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy), [{
117+ // Type resultType = resultTy ? resultTy : $_builder.getIndexType();
118+ // $_state.addAttribute("baseType", baseType);
119+ // $_state.addTypes(resultType);
120+ // }]>,
121+ // OpBuilder<(ins "Type":$baseType), [{
122+ // TypeAttr baseTypeAttr = TypeAttr::get(baseType);
123+ // Type resultType = $_builder.getIndexType();
124+ // $_state.addAttribute("baseType", baseTypeAttr);
125+ // $_state.addTypes(resultType);
126+ // }]>
127+ // ];
128+ // let assemblyFormat = [{
129+ // attr-dict $baseType custom<IntType>(type($result))
130+ // }];
131+ // let extraClassDeclaration = [{
132+ // /// Returns the type offset according to `layout`. If `layout` is `nullopt`
133+ // /// the nearest layout the op will be used for the computation.
134+ // /// ! I copied this from ptr dialect
135+ // llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
136+ // }];
137+ // let hasFolder = 1;
138+ // }
120139
121140def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {
122141 let arguments = (ins AnyMemRef:$input);
123- let results = (outs Ptr_PtrType :$result);
142+ let results = (outs TT_PtrLike :$result);
124143 let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
125144}
126145
127146def TPTR_ToMemrefOp : TPTR_Op<"to_memref", [
128147 Pure ]> {
129- let arguments = (ins Ptr_PtrType :$arg);
148+ let arguments = (ins TT_PtrLike :$arg);
130149 let results = (outs AnyStaticShapeMemRef:$res);
131150 let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
132151}
@@ -143,8 +162,8 @@ def TPTR_PtrAddOp : TPTR_Op<"ptradd", [Pure, AllTypesMatch<["base", "result"]>]>
143162 ```
144163 }];
145164
146- let arguments = (ins Ptr_PtrType :$base, AnySignlessIntegerOrIndex:$offset);
147- let results = (outs Ptr_PtrType :$result);
165+ let arguments = (ins TT_PtrLike :$base, AnySignlessIntegerOrIndex:$offset);
166+ let results = (outs TT_PtrLike :$result);
148167 let assemblyFormat = "$base $offset attr-dict `:` type($base) `,` type($offset) `to` type($result)";
149168}
150169
0 commit comments