Skip to content

Commit 6f718b7

Browse files
zhaoshiz (Zhaoshi Zheng)
and authored
Update Triton to c75c6b034756629b891e7b2df406f634552331d5 (#223)
Trying to fix #178. This PR includes cosmetic changes due to an LLVM API change, a fix for a link error, lit test updates, and the addition of unsupported tests in conftest.py. --------- Co-authored-by: Zhaoshi Zheng <[email protected]>
1 parent 36c6551 commit 6f718b7

File tree

14 files changed

+50
-47
lines changed

14 files changed

+50
-47
lines changed

backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def parse_options(self, opts) -> Any:
167167
args.update({k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts})
168168
return CPUOptions(**args)
169169

170-
def get_codegen_implementation(self):
170+
def get_codegen_implementation(self, options):
171171
codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)}
172172
return codegen_fns
173173

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ struct MakeTensorPtrConverter
215215
SmallVector<Value> newOffsets;
216216
for (auto [offset, stride] :
217217
llvm::zip(pointerState.offsets, pointerState.strides)) {
218-
auto mulOp = rewriter.create<arith::MulIOp>(loc, offset.get<Value>(),
219-
stride.get<Value>());
218+
auto mulOp = rewriter.create<arith::MulIOp>(loc, cast<Value>(offset),
219+
cast<Value>(stride));
220220
newOffsets.push_back(mulOp.getResult());
221221
}
222222

@@ -435,7 +435,7 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
435435
Value dimi = dyn_cast<Value>(mstate.dims[i]);
436436
if (!dimi) {
437437
dimi = rewriter.create<arith::ConstantOp>(
438-
loc, cast<IntegerAttr>(mstate.dims[i].get<Attribute>()));
438+
loc, cast<IntegerAttr>(cast<Attribute>(mstate.dims[i])));
439439
}
440440

441441
auto cmpOp = rewriter.create<arith::CmpIOp>(
@@ -1236,9 +1236,10 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12361236
}
12371237

12381238
bool requiresF32Conversion(const Type elemType, Operation *redOp) const {
1239+
unsigned width =
1240+
cast<FloatType>(Float32Type::get(elemType.getContext())).getWidth();
12391241
return isa<FloatType>(elemType) &&
1240-
elemType.getIntOrFloatBitWidth() <
1241-
Float32Type::get(elemType.getContext()).getWidth() &&
1242+
elemType.getIntOrFloatBitWidth() < width &&
12421243
isa<arith::AddFOp>(redOp);
12431244
}
12441245

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
namespace mlir {
1717

1818
std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
19-
if (ofr.is<Attribute>() && isa<IntegerAttr>(ofr.get<Attribute>()))
20-
return dyn_cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
19+
if (isa<Attribute>(ofr) && isa<IntegerAttr>(cast<Attribute>(ofr)))
20+
return dyn_cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
2121

2222
return std::nullopt;
2323
}
@@ -185,7 +185,7 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
185185

186186
// 2. if lhs is not constant
187187
assert(!lhsIntAttr);
188-
auto mulOp = b.create<arith::MulIOp>(loc, lhs.get<Value>(), rhs);
188+
auto mulOp = b.create<arith::MulIOp>(loc, cast<Value>(lhs), rhs);
189189
return mulOp.getResult();
190190
}
191191

lib/Analysis/PtrAnalysis.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -862,12 +862,12 @@ void PtrAnalysis::rewriteAdvanceOp(
862862
op.getLoc(), rewriter.getIndexAttr(0));
863863
offsetValue = constOp.getResult();
864864
} else {
865-
offsetValue = offset.get<Value>();
865+
offsetValue = cast<Value>(offset);
866866
}
867867
auto castOp = rewriter.create<arith::IndexCastOp>(
868868
loc, rewriter.getIndexType(), increment);
869869
auto mulOp = rewriter.create<arith::MulIOp>(loc, castOp.getResult(),
870-
stride.get<Value>());
870+
cast<Value>(stride));
871871
auto addOp =
872872
rewriter.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
873873
newOffsets.push_back(addOp.getResult());
@@ -999,15 +999,15 @@ void PtrAnalysis::rewriteYieldOp(
999999
op.getLoc(), rewriter.getIndexAttr(0));
10001000
operands.push_back(constOp.getResult());
10011001
} else {
1002-
operands.push_back(s.get<Value>());
1002+
operands.push_back(cast<Value>(s));
10031003
}
10041004
}
10051005

10061006
for (auto s : state.strides) {
10071007
assert(!getIntAttr(s) && "PtrState strides for yield within for "
10081008
"loop not expected to be "
10091009
"attribute.");
1010-
operands.push_back(s.get<Value>());
1010+
operands.push_back(cast<Value>(s));
10111011
}
10121012
}
10131013

@@ -1171,7 +1171,7 @@ void PtrAnalysis::rewriteForOp(
11711171
newInitArgs.push_back(constOp.getResult());
11721172
state.offsets[j] = constOp.getResult();
11731173
} else {
1174-
newInitArgs.push_back(s.get<Value>());
1174+
newInitArgs.push_back(cast<Value>(s));
11751175
}
11761176
}
11771177

@@ -1183,7 +1183,7 @@ void PtrAnalysis::rewriteForOp(
11831183
newInitArgs.push_back(constOp.getResult());
11841184
state.strides[j] = constOp.getResult();
11851185
} else {
1186-
newInitArgs.push_back(s.get<Value>());
1186+
newInitArgs.push_back(cast<Value>(s));
11871187
}
11881188
}
11891189

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -793,12 +793,12 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) {
793793
loc, builder.getIndexAttr(offsetIntAttr.value()));
794794
offsetValue = constOp.getResult();
795795
} else {
796-
offsetValue = offset.get<Value>();
796+
offsetValue = cast<Value>(offset);
797797
}
798798
auto castOp = builder.create<arith::IndexCastOp>(
799799
loc, builder.getIndexType(), increment);
800800
auto mulOp = builder.create<arith::MulIOp>(loc, castOp.getResult(),
801-
stride.get<Value>());
801+
cast<Value>(stride));
802802
auto addOp =
803803
builder.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
804804
newOffsets.push_back(addOp.getResult());
@@ -1029,7 +1029,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
10291029
op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
10301030
replacements.push_back(constOp.getResult());
10311031
} else {
1032-
replacements.push_back(s.get<Value>());
1032+
replacements.push_back(cast<Value>(s));
10331033
}
10341034
}
10351035

@@ -1040,7 +1040,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
10401040
op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
10411041
replacements.push_back(constOp.getResult());
10421042
} else {
1043-
replacements.push_back(s.get<Value>());
1043+
replacements.push_back(cast<Value>(s));
10441044
}
10451045
}
10461046
}

lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class TritonArithToLinalgPass
8181

8282
tensor::populateDecomposeTensorConcatPatterns(patterns);
8383

84-
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
84+
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
8585
return failure();
8686
}
8787
return success();
@@ -103,7 +103,7 @@ class TritonArithToLinalgPass
103103
{
104104
RewritePatternSet patterns(&getContext());
105105
populateTritonArithToLinalgCanonicalizationPatterns(patterns);
106-
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
106+
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
107107
signalPassFailure();
108108
}
109109
}

lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {
100100
{
101101
RewritePatternSet patterns(&getContext());
102102
populateTritonToLinalgCanonicalizationPatterns(patterns);
103-
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
103+
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
104104
signalPassFailure();
105105
}
106106
}

python/examples/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,17 @@ def device(request):
9191
# tt.gather not supported yet
9292
"test_gather",
9393
"test_gather_warp_shuffle",
94-
# device 'cpu' does not have 'index
94+
# device 'cpu' does not have 'index'
9595
"test_zero_strided_tensors",
9696
# hard-coded with 'ttg' attributes
9797
"test_convert_mma2mma",
9898
"test_local_load_store",
99-
"test_local_load_store_mma"
99+
"test_local_load_store_mma",
100+
"test_convert_warp_local",
101+
# hard-code to use 'cuda' device
102+
"test_scan_1d",
103+
"test_tma_load_block_shape_err",
104+
"test_tma_store_block_shape_err"
100105
}
101106

102107
# probably different version of MLIR on the nightly build machine is complaining

test/Conversion/TritonPtrToMemref/post_structured_to_memref.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ module {
1919
%subview = memref.subview %reinterpret_cast[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
2020
%subview_0 = memref.subview %alloc[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
2121
memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
22-
%10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32>
22+
%10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> to tensor<1024xf32>
2323
%reinterpret_cast_1 = memref.reinterpret_cast %1 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
2424
%alloc_2 = memref.alloc() : memref<1024xf32>
2525
%subview_3 = memref.subview %reinterpret_cast_1[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
2626
%subview_4 = memref.subview %alloc_2[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
2727
memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
28-
%11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32>
28+
%11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> to tensor<1024xf32>
2929
%12 = arith.addf %10, %11 : tensor<1024xf32>
3030
%reinterpret_cast_5 = memref.reinterpret_cast %0 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
3131
%extracted_slice = tensor.extract_slice %12[0] [%9] [1] : tensor<1024xf32> to tensor<?xf32>

test/Conversion/TritonPtrToMemref/post_triton_load_store_to_memref.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module {
1515
%8 = tt.splat %arg3 : i32 -> tensor<1024xi32>
1616
%9 = arith.cmpi slt, %7, %8 : tensor<1024xi32>
1717
%cast = memref.cast %2 : memref<*xf32> to memref<?xf32>
18-
%10 = bufferization.to_tensor %cast restrict : memref<?xf32>
18+
%10 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
1919
%11 = tensor.empty() : tensor<1024xf32>
2020
%12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%11 : tensor<1024xf32>) {
2121
^bb0(%in: i32, %in_2: i1, %out: f32):
@@ -30,7 +30,7 @@ module {
3030
linalg.yield %17 : f32
3131
} -> tensor<1024xf32>
3232
%cast_0 = memref.cast %1 : memref<*xf32> to memref<?xf32>
33-
%13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32>
33+
%13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32> to tensor<?xf32>
3434
%14 = tensor.empty() : tensor<1024xf32>
3535
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%14 : tensor<1024xf32>) {
3636
^bb0(%in: i32, %in_2: i1, %out: f32):

0 commit comments

Comments (0)