Skip to content

Commit 454973f

Browse files
committed
Revert "removing mlir ptr dialect and ptr types in codes"
This reverts commit 920d006.
1 parent f000947 commit 454973f

File tree

9 files changed

+162
-164
lines changed

9 files changed

+162
-164
lines changed

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
#include "mlir/Bytecode/BytecodeOpInterface.h"
55
#include "mlir/Interfaces/SideEffectInterfaces.h" // Required for IR/TPtrOps.h.inc
66

7-
// #include "mlir/Dialect/Ptr/IR/PtrDialect.h" // Required for IR/TPtrOps.h.inc
8-
// #include "mlir/Dialect/Ptr/IR/PtrTypes.h" // Required for IR/TPtrOps.h.inc
7+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h" // Required for IR/TPtrOps.h.inc
8+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h" // Required for IR/TPtrOps.h.inc
99

1010
//===----------------------------------------------------------------------===//
1111
// Temporary Pointer Dialect Operations

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33

44
include "mlir/IR/OpBase.td"
55
include "mlir/Interfaces/SideEffectInterfaces.td"
6-
// include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7-
// include "mlir/Dialect/Ptr/IR/PtrDialect.td"
6+
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
7+
include "mlir/Dialect/Ptr/IR/PtrDialect.td"
88
include "mlir/IR/AttrTypeBase.td"
99
include "mlir/IR/BuiltinAttributeInterfaces.td"
1010
include "mlir/IR/BuiltinTypeInterfaces.td"
11-
include "triton/Dialect/Triton/IR/TritonDialect.td"
12-
include "triton/Dialect/Triton/IR/TritonTypes.td"
13-
1411

1512
def TPtr_Dialect : Dialect {
1613
let name = "tptr";
@@ -27,8 +24,8 @@ def TPtr_Dialect : Dialect {
2724
void registerTypes();
2825
}];
2926

30-
let dependentDialects = [
31-
"mlir::triton::TritonDialect"
27+
let dependentDialects = [
28+
"mlir::ptr::PtrDialect"
3229
];
3330

3431
let useDefaultAttributePrinterParser = 1;
@@ -50,10 +47,10 @@ class TPtr_Attr<string name, string _mnemonic,
5047
// Memory space attr is required for building Ptr ops
5148
// This acts as default memory space since there is
5249
// no such default implemented upstream
53-
// def DefaultMemorySpaceAttr
54-
// : TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
55-
// [DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
56-
// }
50+
def DefaultMemorySpaceAttr
51+
: TPtr_Attr<"DefaultMemorySpace", "default_memory_space",
52+
[DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>]> {
53+
}
5754

5855
//
5956
// Op Base
@@ -75,7 +72,7 @@ def TPTR_IntToPtrOp : TPTR_Op<"inttoptr", [
7572
```
7673
}];
7774
let arguments = (ins AnySignlessIntegerOrIndex:$arg);
78-
let results = (outs TT_PtrLike:$res);
75+
let results = (outs Ptr_PtrType:$res);
7976
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
8077
}
8178

@@ -91,61 +88,45 @@ def TPTR_PtrToIntOp : TPTR_Op<"ptrtoint", [
9188
%int = ptr.ptrtoint %ptr : !ptr.ptr<1 : i32> to i32
9289
```
9390
}];
94-
let arguments = (ins TT_PtrLike:$arg);
91+
let arguments = (ins Ptr_PtrType:$arg);
9592
let results = (outs AnySignlessIntegerOrIndex:$res);
9693
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
9794
}
9895

99-
// def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
100-
// let summary = "Creates a type offset constant.";
101-
// let description = [{
102-
// The `addr.type_offset` operation produces an int or index-typed SSA value
103-
// equal to a target-specific constant representing the offset of a single
104-
// element of the given type. The default return type is `index`.
105-
// Example:
106-
107-
// ```mlir
108-
// %0 = addr.type_offset f32
109-
// %1 = addr.type_offset memref<12 x f64> : i32
110-
// ```
111-
// }];
112-
113-
// let arguments = (ins TypeAttr:$baseType);
114-
// let results = (outs AnySignlessIntegerOrIndex:$result);
115-
// let builders = [
116-
// OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy), [{
117-
// Type resultType = resultTy ? resultTy : $_builder.getIndexType();
118-
// $_state.addAttribute("baseType", baseType);
119-
// $_state.addTypes(resultType);
120-
// }]>,
121-
// OpBuilder<(ins "Type":$baseType), [{
122-
// TypeAttr baseTypeAttr = TypeAttr::get(baseType);
123-
// Type resultType = $_builder.getIndexType();
124-
// $_state.addAttribute("baseType", baseTypeAttr);
125-
// $_state.addTypes(resultType);
126-
// }]>
127-
// ];
128-
// let assemblyFormat = [{
129-
// attr-dict $baseType custom<IntType>(type($result))
130-
// }];
131-
// let extraClassDeclaration = [{
132-
// /// Returns the type offset according to `layout`. If `layout` is `nullopt`
133-
// /// the nearest layout the op will be used for the computation.
134-
// /// ! I copied this from ptr dialect
135-
// llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
136-
// }];
137-
// let hasFolder = 1;
138-
// }
96+
def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
97+
let summary = "Creates a type offset constant.";
98+
let description = [{
99+
The `addr.type_offset` operation produces an int or index-typed SSA value
100+
equal to a target-specific constant representing the offset of a single
101+
element of the given type. The default return type is `index`.
102+
Example:
103+
104+
```mlir
105+
%0 = addr.type_offset f32
106+
%1 = addr.type_offset memref<12 x f64> : i32
107+
```
108+
}];
109+
110+
let arguments = (ins TypeAttr:$baseType);
111+
let results = (outs AnySignlessIntegerOrIndex:$result);
112+
let builders = [
113+
OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy)>
114+
];
115+
let assemblyFormat = [{
116+
attr-dict $baseType custom<IntType>(type($result))
117+
}];
118+
let hasFolder = 1;
119+
}
139120

140121
def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {
141122
let arguments = (ins AnyMemRef:$input);
142-
let results = (outs TT_PtrLike:$result);
123+
let results = (outs Ptr_PtrType:$result);
143124
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
144125
}
145126

146127
def TPTR_ToMemrefOp : TPTR_Op<"to_memref", [
147128
Pure ]> {
148-
let arguments = (ins TT_PtrLike:$arg);
129+
let arguments = (ins Ptr_PtrType:$arg);
149130
let results = (outs AnyStaticShapeMemRef:$res);
150131
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
151132
}
@@ -162,8 +143,8 @@ def TPTR_PtrAddOp : TPTR_Op<"ptradd", [Pure, AllTypesMatch<["base", "result"]>]>
162143
```
163144
}];
164145

165-
let arguments = (ins TT_PtrLike:$base, AnySignlessIntegerOrIndex:$offset);
166-
let results = (outs TT_PtrLike:$result);
146+
let arguments = (ins Ptr_PtrType:$base, AnySignlessIntegerOrIndex:$offset);
147+
let results = (outs Ptr_PtrType:$result);
167148
let assemblyFormat = "$base $offset attr-dict `:` type($base) `,` type($offset) `to` type($result)";
168149
}
169150

lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2626
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2727
#include "mlir/Dialect/MemRef/IR/MemRef.h"
28+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
2829
#include "mlir/Pass/PassManager.h"
2930
#include "mlir/Transforms/Passes.h"
3031

lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
// #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
19+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
2020
#include "mlir/IR/Builders.h"
2121
#include "mlir/IR/BuiltinDialect.h"
2222
#include "mlir/IR/BuiltinOps.h"
@@ -25,9 +25,12 @@
2525
#include "mlir/Pass/PassManager.h"
2626
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2727
#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
28-
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
28+
2929
#include "triton/Dialect/Triton/IR/Types.h"
3030

31+
#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
32+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
33+
3134
using namespace mlir;
3235
using namespace triton;
3336

@@ -84,18 +87,17 @@ struct FromMemrefConverter
8487
auto output = op.getResult(0);
8588
auto outType = output.getType();
8689

87-
if (unrankedInput && isa<triton::PointerType /*, ptr::PtrType*/>(outType)) {
90+
if (unrankedInput && isa<triton::PointerType, ptr::PtrType>(outType)) {
8891
// from_memref only takes ranked memref, cast the unranked memref to
8992
// ranked memref first.
9093
auto rankedMemref = rewriter.create<memref::CastOp>(
9194
op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()),
9295
input);
9396
auto memrefToPtr = rewriter.create<tptr::FromMemrefOp>(
9497
op->getLoc(),
95-
// ptr::PtrType::get(
96-
// rewriter.getContext(),
97-
// tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())),
98-
triton::PointerType::get(unrankedInput.getElementType(), 1),
98+
ptr::PtrType::get(
99+
rewriter.getContext(),
100+
tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())),
99101
rankedMemref);
100102

101103
rewriter.replaceAllUsesWith(output, memrefToPtr);
@@ -121,7 +123,7 @@ struct ToMemrefConverter : public OpRewritePattern<UnrealizedConversionCastOp> {
121123
auto inType = input.getType();
122124
auto output = op.getResult(0);
123125
auto outUnrankedMemrefType = dyn_cast<UnrankedMemRefType>(output.getType());
124-
if (isa<triton::PointerType/*, ptr::PtrType*/>(inType) &&
126+
if (isa<triton::PointerType, ptr::PtrType>(inType) &&
125127
outUnrankedMemrefType) {
126128
// to_memref can only cast to ranked static shape memref, we have to cast
127129
// the resulting memref back to unranked

lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//
66
//===----------------------------------------------------------------------===//
77

8+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
89
#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
910
#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
1011
#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
@@ -45,7 +46,7 @@ class TritonToLinalgExperimentalPass
4546
scf::SCFDialect, tensor::TensorDialect,
4647
bufferization::BufferizationDialect, memref::MemRefDialect,
4748
ttx::TritonTilingExtDialect, tts::TritonStructuredDialect,
48-
tptr::TPtrDialect>();
49+
tptr::TPtrDialect, ptr::PtrDialect>();
4950
}
5051

5152
void runOnOperation() override {

0 commit comments

Comments
 (0)