Skip to content

Commit 920d006

Browse files
committed
removing mlir ptr dialect and ptr types in codes
1 parent 0edb877 commit 920d006

File tree

9 files changed

+164
-162
lines changed

9 files changed

+164
-162
lines changed

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
#include "mlir/Bytecode/BytecodeOpInterface.h"
55
#include "mlir/Interfaces/SideEffectInterfaces.h" // Required for IR/TPtrOps.h.inc
66

7-
#include "mlir/Dialect/Ptr/IR/PtrDialect.h" // Required for IR/TPtrOps.h.inc
8-
#include "mlir/Dialect/Ptr/IR/PtrTypes.h" // Required for IR/TPtrOps.h.inc
7+
// #include "mlir/Dialect/Ptr/IR/PtrDialect.h" // Required for IR/TPtrOps.h.inc
8+
// #include "mlir/Dialect/Ptr/IR/PtrTypes.h" // Required for IR/TPtrOps.h.inc
99

1010
//===----------------------------------------------------------------------===//
1111
// Temporary Pointer Dialect Operations

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33

44
include "mlir/IR/OpBase.td"
55
include "mlir/Interfaces/SideEffectInterfaces.td"
6-
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7-
include "mlir/Dialect/Ptr/IR/PtrDialect.td"
6+
// include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7+
// include "mlir/Dialect/Ptr/IR/PtrDialect.td"
88
include "mlir/IR/AttrTypeBase.td"
99
include "mlir/IR/BuiltinAttributeInterfaces.td"
1010
include "mlir/IR/BuiltinTypeInterfaces.td"
11+
include "triton/Dialect/Triton/IR/TritonDialect.td"
12+
include "triton/Dialect/Triton/IR/TritonTypes.td"
13+
1114

1215
def TPtr_Dialect : Dialect {
1316
let name = "tptr";
@@ -24,8 +27,8 @@ def TPtr_Dialect : Dialect {
2427
void registerTypes();
2528
}];
2629

27-
let dependentDialects = [
28-
"mlir::ptr::PtrDialect"
30+
let dependentDialects = [
31+
"mlir::triton::TritonDialect"
2932
];
3033

3134
let useDefaultAttributePrinterParser = 1;
@@ -47,10 +50,10 @@ class TPtr_Attr<string name, string _mnemonic,
4750
// Memory space attr is required for building Ptr ops
4851
// This acts as default memory space since there is
4952
// no such default implemented upstream
50-
def DefaultMemorySpaceAttr
51-
: TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
52-
[DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
53-
}
53+
// def DefaultMemorySpaceAttr
54+
// : TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
55+
// [DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
56+
// }
5457

5558
//
5659
// Op Base
@@ -72,7 +75,7 @@ def TPTR_IntToPtrOp : TPTR_Op<"inttoptr", [
7275
```
7376
}];
7477
let arguments = (ins AnySignlessIntegerOrIndex:$arg);
75-
let results = (outs Ptr_PtrType:$res);
78+
let results = (outs TT_PtrLike:$res);
7679
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
7780
}
7881

@@ -88,45 +91,61 @@ def TPTR_PtrToIntOp : TPTR_Op<"ptrtoint", [
8891
%int = ptr.ptrtoint %ptr : !ptr.ptr<1 : i32> to i32
8992
```
9093
}];
91-
let arguments = (ins Ptr_PtrType:$arg);
94+
let arguments = (ins TT_PtrLike:$arg);
9295
let results = (outs AnySignlessIntegerOrIndex:$res);
9396
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
9497
}
9598

96-
def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
97-
let summary = "Creates a type offset constant.";
98-
let description = [{
99-
The `addr.type_offset` operation produces an int or index-typed SSA value
100-
equal to a target-specific constant representing the offset of a single
101-
element of the given type. The default return type is `index`.
102-
Example:
103-
104-
```mlir
105-
%0 = addr.type_offset f32
106-
%1 = addr.type_offset memref<12 x f64> : i32
107-
```
108-
}];
109-
110-
let arguments = (ins TypeAttr:$baseType);
111-
let results = (outs AnySignlessIntegerOrIndex:$result);
112-
let builders = [
113-
OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy)>
114-
];
115-
let assemblyFormat = [{
116-
attr-dict $baseType custom<IntType>(type($result))
117-
}];
118-
let hasFolder = 1;
119-
}
99+
// def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
100+
// let summary = "Creates a type offset constant.";
101+
// let description = [{
102+
// The `addr.type_offset` operation produces an int or index-typed SSA value
103+
// equal to a target-specific constant representing the offset of a single
104+
// element of the given type. The default return type is `index`.
105+
// Example:
106+
107+
// ```mlir
108+
// %0 = addr.type_offset f32
109+
// %1 = addr.type_offset memref<12 x f64> : i32
110+
// ```
111+
// }];
112+
113+
// let arguments = (ins TypeAttr:$baseType);
114+
// let results = (outs AnySignlessIntegerOrIndex:$result);
115+
// let builders = [
116+
// OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy), [{
117+
// Type resultType = resultTy ? resultTy : $_builder.getIndexType();
118+
// $_state.addAttribute("baseType", baseType);
119+
// $_state.addTypes(resultType);
120+
// }]>,
121+
// OpBuilder<(ins "Type":$baseType), [{
122+
// TypeAttr baseTypeAttr = TypeAttr::get(baseType);
123+
// Type resultType = $_builder.getIndexType();
124+
// $_state.addAttribute("baseType", baseTypeAttr);
125+
// $_state.addTypes(resultType);
126+
// }]>
127+
// ];
128+
// let assemblyFormat = [{
129+
// attr-dict $baseType custom<IntType>(type($result))
130+
// }];
131+
// let extraClassDeclaration = [{
132+
// /// Returns the type offset according to `layout`. If `layout` is `nullopt`
133+
// /// the nearest layout the op will be used for the computation.
134+
// /// ! I copied this from ptr dialect
135+
// llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
136+
// }];
137+
// let hasFolder = 1;
138+
// }
120139

121140
def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {
122141
let arguments = (ins AnyMemRef:$input);
123-
let results = (outs Ptr_PtrType:$result);
142+
let results = (outs TT_PtrLike:$result);
124143
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
125144
}
126145

127146
def TPTR_ToMemrefOp : TPTR_Op<"to_memref", [
128147
Pure ]> {
129-
let arguments = (ins Ptr_PtrType:$arg);
148+
let arguments = (ins TT_PtrLike:$arg);
130149
let results = (outs AnyStaticShapeMemRef:$res);
131150
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
132151
}
@@ -143,8 +162,8 @@ def TPTR_PtrAddOp : TPTR_Op<"ptradd", [Pure, AllTypesMatch<["base", "result"]>]>
143162
```
144163
}];
145164

146-
let arguments = (ins Ptr_PtrType:$base, AnySignlessIntegerOrIndex:$offset);
147-
let results = (outs Ptr_PtrType:$result);
165+
let arguments = (ins TT_PtrLike:$base, AnySignlessIntegerOrIndex:$offset);
166+
let results = (outs TT_PtrLike:$result);
148167
let assemblyFormat = "$base $offset attr-dict `:` type($base) `,` type($offset) `to` type($result)";
149168
}
150169

lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2626
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2727
#include "mlir/Dialect/MemRef/IR/MemRef.h"
28-
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
2928
#include "mlir/Pass/PassManager.h"
3029
#include "mlir/Transforms/Passes.h"
3130

lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,17 @@
1616

1717
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
19+
// #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
2020
#include "mlir/IR/Builders.h"
2121
#include "mlir/IR/BuiltinDialect.h"
2222
#include "mlir/IR/BuiltinOps.h"
2323
#include "mlir/IR/BuiltinTypes.h"
2424
#include "mlir/IR/ValueRange.h"
2525
#include "mlir/Pass/PassManager.h"
2626
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27-
#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
28-
29-
#include "triton/Dialect/Triton/IR/Types.h"
30-
3127
#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
3228
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
29+
#include "triton/Dialect/Triton/IR/Types.h"
3330

3431
using namespace mlir;
3532
using namespace triton;
@@ -87,17 +84,18 @@ struct FromMemrefConverter
8784
auto output = op.getResult(0);
8885
auto outType = output.getType();
8986

90-
if (unrankedInput && isa<triton::PointerType, ptr::PtrType>(outType)) {
87+
if (unrankedInput && isa<triton::PointerType /*, ptr::PtrType*/>(outType)) {
9188
// from_memref only takes ranked memref, cast the unranked memref to
9289
// ranked memref first.
9390
auto rankedMemref = rewriter.create<memref::CastOp>(
9491
op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()),
9592
input);
9693
auto memrefToPtr = rewriter.create<tptr::FromMemrefOp>(
9794
op->getLoc(),
98-
ptr::PtrType::get(
99-
rewriter.getContext(),
100-
tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())),
95+
// ptr::PtrType::get(
96+
// rewriter.getContext(),
97+
// tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())),
98+
triton::PointerType::get(unrankedInput.getElementType(), 1),
10199
rankedMemref);
102100

103101
rewriter.replaceAllUsesWith(output, memrefToPtr);
@@ -123,7 +121,7 @@ struct ToMemrefConverter : public OpRewritePattern<UnrealizedConversionCastOp> {
123121
auto inType = input.getType();
124122
auto output = op.getResult(0);
125123
auto outUnrankedMemrefType = dyn_cast<UnrankedMemRefType>(output.getType());
126-
if (isa<triton::PointerType, ptr::PtrType>(inType) &&
124+
if (isa<triton::PointerType/*, ptr::PtrType*/>(inType) &&
127125
outUnrankedMemrefType) {
128126
// to_memref can only cast to ranked static shape memref, we have to cast
129127
// the resulting memref back to unranked

lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
//
66
//===----------------------------------------------------------------------===//
77

8-
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
98
#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
109
#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
1110
#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
@@ -46,7 +45,7 @@ class TritonToLinalgExperimentalPass
4645
scf::SCFDialect, tensor::TensorDialect,
4746
bufferization::BufferizationDialect, memref::MemRefDialect,
4847
ttx::TritonTilingExtDialect, tts::TritonStructuredDialect,
49-
tptr::TPtrDialect, ptr::PtrDialect>();
48+
tptr::TPtrDialect>();
5049
}
5150

5251
void runOnOperation() override {

0 commit comments

Comments
 (0)