
Commit b673908

Update triton to e44bd1c83c1c3e8deac7c4f02683cfb3cc395c8b (#345)
This updates to the latest Triton main, which should resolve the nightly build issues. Primary changes:

* Change the conversion of tts::MakeGatherScatterTensorPtrOp to insert an identity cast so that DialectConversion does not ignore the type-converted operand of the associated tts::LoadOp when supplying the op adaptor to the pattern.
* Move support for unsplat to the new dedicated operation, since it no longer performs a reduction.
* Add the missing test fixture to conftest.
* Disable float annotation tests, since bfloat16 and float16 are not supported in the CPU backend.
* Add the missing link library when registering passes.
* Update construction of ValueRange{}, which was causing compilation errors.
* Remove the unused builder in TPtrOps, since it caused linker errors.
* Remove bfloat16 and float16 from the CPU backend, since no testing is present; better to crash during compilation than to get runtime errors.
* Remove GPUDialect, which is no longer present upstream.
1 parent 447ffe3 · commit b673908

File tree

13 files changed (+68, -57 lines)

.gitmodules

Lines changed: 0 additions & 3 deletions

@@ -1,3 +0,0 @@
-[submodule "triton"]
-	path = triton
-	url = https://github.com/triton-lang/triton.git

backend/driver.py

Lines changed: 4 additions & 2 deletions

@@ -54,8 +54,10 @@ def _ty_to_cpp(ty):
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
+        # Proper support for bfloat16 and float16 is not yet handled.
+        # https://github.com/microsoft/triton-shared/issues/348
+        # "fp16": "TODO",
+        # "bf16": "TODO",
         "fp32": "float",
         "f32": "float",
         "fp64": "double",

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 24 additions & 28 deletions

@@ -360,7 +360,7 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
                                          loc, rewriter);
     auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
     auto loadOp = rewriter.create<affine::AffineLoadOp>(
-        op.getLoc(), sMemRef, zeroMap, std::nullopt);
+        op.getLoc(), sMemRef, zeroMap, ValueRange{});
     rewriter.replaceOp(op, loadOp.getResult());
     return success();
   }
@@ -520,7 +520,7 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
         PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter);
     auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
     rewriter.create<affine::AffineStoreOp>(loc, val, sMemRef, zeroMap,
-                                           std::nullopt);
+                                           ValueRange{});
     rewriter.eraseOp(op);
     return success();
   }
@@ -649,6 +649,28 @@ struct SplatConverter : public OpConversionPattern<triton::SplatOp> {
   }
 };
 
+struct UnsplatConverter : public OpConversionPattern<triton::UnsplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::UnsplatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tensorType = op.getSrc().getType();
+
+    // Only generate indices for non-zero rank tensors.
+    SmallVector<Value, 1> indices(tensorType.getRank());
+    if (indices.size() > 0) {
+      auto zeroIdx =
+          rewriter.createOrFold<arith::ConstantIndexOp>(op.getLoc(), 0);
+      llvm::fill(indices, zeroIdx);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getSrc(),
+                                                   indices);
+    return success();
+  }
+};
+
 struct BroadcastConverter : public OpConversionPattern<triton::BroadcastOp> {
 private:
   using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
@@ -1397,24 +1419,6 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
     return success();
   }
 
-  LogicalResult
-  convertToTensorExtract(triton::ReduceOp op,
-                         typename triton::ReduceOp::Adaptor adaptor,
-                         ConversionPatternRewriter &rewriter) const {
-    assert(llvm::hasSingleElement(op.getSrcs()));
-
-    auto returnOp = cast<triton::ReduceReturnOp>(*op.getOps().begin());
-    assert(llvm::hasSingleElement(returnOp.getResult()));
-    assert(cast<BlockArgument>(returnOp.getResult().front()).getArgNumber() ==
-           0);
-
-    auto source = op.getSrcs().front();
-    auto zeroIdx =
-        rewriter.createOrFold<arith::ConstantIndexOp>(op.getLoc(), 0);
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, source, zeroIdx);
-    return success();
-  }
-
 public:
   LogicalResult
   matchAndRewrite(triton::ReduceOp op,
@@ -1431,14 +1435,6 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
                              "axis is within "
                              "operand's rank");
 
-    // Unsplat is implemented as a single element, rank 1 reduction where
-    // single element is yielded immediately. This can be simplified into
-    // a single element extract.
-    if (llvm::hasSingleElement(op.getOps()) && sourceType.getRank() == 1 &&
-        sourceType.getShape()[0] == 1) {
-      return convertToTensorExtract(op, adaptor, rewriter);
-    }
-
     return convertToLinalgReduce(op, adaptor, rewriter);
   }
 };

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td

Lines changed: 0 additions & 3 deletions

@@ -109,9 +109,6 @@ def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
 
   let arguments = (ins TypeAttr:$baseType);
   let results = (outs AnySignlessIntegerOrIndex:$result);
-  let builders = [
-    OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy)>
-  ];
   let assemblyFormat = [{
     attr-dict $baseType custom<IntType>(type($result))
   }];

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 6 additions & 9 deletions

@@ -577,21 +577,18 @@ struct MakeTensorPtrConverter
 
 struct MakeGatherScatterTensorPtrConverter
     : public OpConversionPattern<tts::MakeGatherScatterTensorPtrOp> {
-private:
-  using OpConversionPattern<tts::MakeGatherScatterTensorPtrOp>::OpConversionPattern;
-
-public:
-  MakeGatherScatterTensorPtrConverter(const TypeConverter &typeConverter,
-                                      MLIRContext *context)
-      : OpConversionPattern<tts::MakeGatherScatterTensorPtrOp>(typeConverter, context) {}
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(tts::MakeGatherScatterTensorPtrOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // The gatherScatterPtr is rewritten as separate rows during load/store
     // operations. Therefore, no action is needed here except saving
-    // adaptor.getBase().
-    rewriter.replaceOp(op, adaptor.getBase());
+    // adaptor.getBase(). DialectConversion will ignore pure type conversion if
+    // we were to simply replace the op with adaptor.getBase(). To circumvent
+    // this we create an identity cast.
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, adaptor.getBase().getType(), adaptor.getBase());
     return success();
   }
 };

lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
   patterns.add<ClampConverter>(patterns.getContext());
   patterns.add<MatmulConverter>(patterns.getContext());
   patterns.add<SplatConverter>(patterns.getContext());
+  patterns.add<UnsplatConverter>(patterns.getContext());
   patterns.add<DenseConstantConverter>(patterns.getContext());
   patterns.add<CumSumConverter>(patterns.getContext());
   patterns.add<ReshapeConverter>(patterns.getContext());

lib/Conversion/TritonToLinalg/TritonToLinalg.cpp

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
   patterns.add<AssertConverter>(patterns.getContext());
   patterns.add<MatmulConverter>(patterns.getContext());
   patterns.add<SplatConverter>(patterns.getContext());
+  patterns.add<UnsplatConverter>(patterns.getContext());
   patterns.add<DenseConstantConverter>(patterns.getContext());
   patterns.add<UnrealizedCastConverter>(patterns.getContext());
   patterns.add<CumSumConverter>(patterns.getContext());

lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp

Lines changed: 2 additions & 2 deletions

@@ -104,7 +104,7 @@ struct ScalarLoadConverter : public OpConversionPattern<tts::GatherOp> {
     auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
 
     auto scalarLoadOp = rewriter.create<affine::AffineLoadOp>(
-        loc, memref, zeroMap, std::nullopt);
+        loc, memref, zeroMap, ValueRange{});
 
     rewriter.replaceOp(gatherOp, scalarLoadOp.getResult());
 
@@ -150,7 +150,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
     auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
 
     rewriter.create<affine::AffineStoreOp>(loc, storeVal, memref, zeroMap,
-                                           std::nullopt);
+                                           ValueRange{});
     rewriter.eraseOp(scatterOp);
 
     return success();

python/examples/conftest.py

Lines changed: 24 additions & 2 deletions

@@ -18,6 +18,19 @@ def empty_decorator(func):
 def device(request):
     return "cpu"
 
+
+# this fixture is used for test_enable_fp_fusion
+@pytest.fixture
+def fresh_knobs():
+    from triton._internal_testing import _fresh_knobs_impl
+
+    fresh_function, reset_function = _fresh_knobs_impl()
+    try:
+        yield fresh_function()
+    finally:
+        reset_function()
+
+
 # this fixture is used for test_trans_4d && test_trans_reshape
 @pytest.fixture
 def with_allocator():
@@ -32,7 +45,7 @@ def with_allocator():
     triton.set_allocator(NullAllocator())
 
 
-tests_supported = {
+core_tests_supported = {
     "test_store_eviction_policy",
     "test_unary_op",
     "test_umulhi",
@@ -77,6 +90,11 @@ def with_allocator():
     "test_arange",
 }
 
+annotations_tests_supported = {
+    "test_int_annotation",
+    "test_unknown_annotation",
+}
+
 
 def pytest_collection_modifyitems(config, items):
     skip_marker = pytest.mark.skip(reason="CPU backend does not support it yet")
@@ -89,7 +107,11 @@ def pytest_collection_modifyitems(config, items):
         test_func_name = item.originalname if item.originalname else item.name
 
         test_file = str(item.fspath)
-        if test_file.endswith("test_core.py") and test_func_name not in tests_supported:
+        if test_file.endswith("test_core.py") and test_func_name not in core_tests_supported:
+            item.add_marker(skip_marker)
+            continue
+
+        if test_file.endswith("test_annotations.py") and test_func_name not in annotations_tests_supported:
             item.add_marker(skip_marker)
             continue
 
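For context, a hedged usage sketch (not part of this commit) of how a test consumes the new fixture: pytest injects the object yielded by fresh_knobs, and the fixture's finally block restores the global knob state afterwards.

def test_with_isolated_knobs(fresh_knobs):
    # Mutate compiler knobs here, as test_enable_fp_fusion does; the
    # fixture's reset_function() rolls the changes back after the test,
    # so configuration cannot leak into other tests.
    assert fresh_knobs is not None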

test/Conversion/TritonToLinalgExperimental/reduce_unsplat.mlir renamed to test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir

Lines changed: 1 addition & 4 deletions

@@ -6,10 +6,7 @@ module {
     %0 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
     %1 = tt.load %0 : tensor<1x!tt.ptr<i32>>
     %2 = arith.cmpi sgt, %1, %cst : tensor<1xi32>
-    %3 = "tt.reduce"(%2) <{axis = 0 : i32}> ({
-    ^bb0(%arg1: i1, %arg2: i1):
-      tt.reduce.return %arg1 : i1
-    }) : (tensor<1xi1>) -> i1
+    %3 = tt.unsplat %2 : tensor<1xi1>
     scf.if %3 {
       tt.store %arg0, %c42_i32 : !tt.ptr<i32>
     }
