Commit f899e27

Merge commit '6ca2dda9bdd331d007d6fab342db5a85f9b23c7d'
2 parents: 37866dc + 6ca2dda

20 files changed: +94 / -104 lines

bin/CMakeLists.txt

Lines changed: 9 additions & 10 deletions
@@ -1,5 +1,3 @@
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
 get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

 add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
@@ -8,8 +6,6 @@ add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
 llvm_update_compile_flags(triton-opt)
 target_link_libraries(triton-opt PRIVATE
   TritonIntelLLVMIR
-  ${dialect_libs}
-  ${conversion_libs}
   ${triton_libs}
   # tests
   TritonTestAnalysis
@@ -19,6 +15,8 @@ target_link_libraries(triton-opt PRIVATE
   # MLIR core
   MLIROptLib
   MLIRPass
+  MLIRRegisterAllDialects
+  MLIRRegisterAllPasses
   MLIRTransforms
 )

@@ -29,8 +27,6 @@ mlir_check_all_link_libraries(triton-reduce)

 llvm_update_compile_flags(triton-reduce)
 target_link_libraries(triton-reduce PRIVATE
-  ${dialect_libs}
-  ${conversion_libs}
   ${triton_libs}
   # tests
   TritonTestAnalysis
@@ -40,6 +36,8 @@ target_link_libraries(triton-reduce PRIVATE
   # MLIR core
   MLIRReduceLib
   MLIRPass
+  MLIRRegisterAllDialects
+  MLIRRegisterAllPasses
   MLIRTransforms
 )

@@ -49,8 +47,6 @@ add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED)

 llvm_update_compile_flags(triton-lsp)
 target_link_libraries(triton-lsp PRIVATE
-  ${dialect_libs}
-  ${conversion_libs}
   ${triton_libs}
   # tests
   TritonTestAnalysis
@@ -60,6 +56,8 @@ target_link_libraries(triton-lsp PRIVATE
   # MLIR core
   MLIRLspServerLib
   MLIRPass
+  MLIRRegisterAllDialects
+  MLIRRegisterAllPasses
   MLIRTransforms
 )

@@ -91,10 +89,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
 add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
 target_link_libraries(triton-tensor-layout PRIVATE
   ${triton_libs}
-  ${conversion_libs}
-  ${dialect_libs}
   TritonTestAnalysis
   TritonTestDialect
   TritonTestProton
   TritonAMDGPUTestAnalysis
+  MLIRRegisterAllDialects
+  MLIRRegisterAllPasses
+  MLIRTransforms
 )
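
Note on the bin/CMakeLists.txt changes: instead of pulling in every registered dialect and conversion library via ${dialect_libs}/${conversion_libs}, each tool now links the upstream umbrella registration libraries MLIRRegisterAllDialects and MLIRRegisterAllPasses. A minimal sketch of the mlir-opt-style entry point those libraries serve, assuming current upstream MLIR headers; the actual triton-opt.cpp wiring may differ:

// Sketch: an mlir-opt-style driver whose registration calls require
// MLIRRegisterAllDialects / MLIRRegisterAllPasses at link time.
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry); // backed by MLIRRegisterAllDialects
  mlir::registerAllPasses();           // backed by MLIRRegisterAllPasses
  return mlir::failed(mlir::MlirOptMain(
             argc, argv, "Example optimizer driver\n", registry))
             ? 1
             : 0;
}

Linking the two umbrella libraries keeps the link line stable as dialects and passes are added or removed upstream, which is presumably why the per-dialect library lists could be dropped here.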

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-570885128351868c1308bb22e8ca351d318bc4a1
+bc773632355b3cebde350b0341624e88be40b744

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 9 deletions
@@ -100,12 +100,6 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
     DefaultValuedAttr<BoolAttr, "false">:$isVolatile
   );

-  let builders = [
-    OpBuilder<(ins "Value":$src, "Value":$result,
-               "triton::CacheModifier":$cache,
-               "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
-  ];
-
   let results = (outs TTG_AsyncToken:$token);

   let extraClassDeclaration = [{
@@ -395,9 +389,6 @@ def TTG_MaskOp: TTG_Op<"mask",
   let arguments = (ins I1:$pred);
   let results = (outs Variadic<AnyType>:$result);
   let regions = (region SizedRegion<1>:$region);
-  let builders = [
-    OpBuilder<(ins "Value":$pred)>,
-  ];
 }

 def TTG_MaskReturnOp: TTG_Op<"mask.return",

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ Value generateMaskFromOffsetRanges(OpBuilder &builder, const Location &loc,

   // Compare with lower bound
   Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
-      loc, 0, builder.getI64Type());
+      loc, builder.getI64Type(), 0);
   Value splatLowerBound = builder.create<triton::SplatOp>(
       loc, offsetWithRange.getType(), lowerBound);
   Value cmpLower = builder.create<arith::CmpIOp>(

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ struct RewritedInfo {

     // Compare with lower bound
     Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
-        loc, 0, builder.getI64Type());
+        loc, builder.getI64Type(), 0);
     Value splatLowerBound = builder.create<triton::SplatOp>(
         loc, offsetWithRange.getType(), lowerBound);
     Value cmpLower = builder.create<arith::CmpIOp>(

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 2 additions & 2 deletions
@@ -67,11 +67,11 @@ static void expandLoops(ModuleOp moduleOp) {
     if (isEpilogue) {
       // Return false for the predicate of the peeled iteration
       return rewriter.create<mlir::arith::ConstantIntOp>(
-          predOp.getLoc(), 0, predOp.getResult().getType());
+          predOp.getLoc(), predOp.getResult().getType(), 0);
     } else {
       if (predOp.getStage() == predOp.getMaxStage() - 1) {
         return rewriter.create<mlir::arith::ConstantIntOp>(
-            predOp.getLoc(), 1, predOp.getResult().getType());
+            predOp.getLoc(), predOp.getResult().getType(), 1);
       } else {
         OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPoint(op);
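
The three C++ rewrites above are the same mechanical fix: after the LLVM bump recorded in cmake/llvm-hash.txt, the arith::ConstantIntOp builder takes the result type before the integer value. A minimal before/after sketch under that assumption; makeZeroI64 is an illustrative helper, not code from this commit:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper showing the reordered builder arguments.
mlir::Value makeZeroI64(mlir::OpBuilder &builder, mlir::Location loc) {
  mlir::Type i64Ty = builder.getI64Type();
  // Before the bump: builder.create<mlir::arith::ConstantIntOp>(loc, 0, i64Ty);
  // After the bump: the result type comes first, then the value.
  return builder.create<mlir::arith::ConstantIntOp>(loc, i64Ty, 0);
}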

python/src/ir.cc

Lines changed: 13 additions & 13 deletions
@@ -790,53 +790,53 @@ void init_triton_ir(py::module &&m) {
       .def("get_int1",
            [](TritonOpBuilder &self, bool v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI1Type()));
+                 self.getBuilder().getI1Type(), v));
            })
       .def("get_int8",
            [](TritonOpBuilder &self, int64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI8Type()));
+                 self.getBuilder().getI8Type(), v));
            })
       .def("get_int16",
            [](TritonOpBuilder &self, int64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI16Type()));
+                 self.getBuilder().getI16Type(), v));
            })
       .def("get_int32",
            [](TritonOpBuilder &self, int64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI32Type()));
+                 self.getBuilder().getI32Type(), v));
            })
       .def("get_int64",
            [](TritonOpBuilder &self, int64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI64Type()));
+                 self.getBuilder().getI64Type(), v));
            })
       .def("get_uint8",
            [](TritonOpBuilder &self, uint64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI8Type()));
+                 self.getBuilder().getI8Type(), v));
            })
       .def("get_uint16",
            [](TritonOpBuilder &self, uint64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI16Type()));
+                 self.getBuilder().getI16Type(), v));
            })
       .def("get_uint32",
            [](TritonOpBuilder &self, uint64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI32Type()));
+                 self.getBuilder().getI32Type(), v));
            })
       .def("get_uint64",
            [](TritonOpBuilder &self, uint64_t v) -> Value {
              return Value(self.create<arith::ConstantIntOp>(
-                 v, self.getBuilder().getI64Type()));
+                 self.getBuilder().getI64Type(), v));
            })
       .def("get_bf16",
            [](TritonOpBuilder &self, float v) -> Value {
              auto type = self.getBuilder().getBF16Type();
              return self.create<arith::ConstantFloatOp>(
-                 APFloat(type.getFloatSemantics(), std::to_string(v)), type);
+                 type, APFloat(type.getFloatSemantics(), std::to_string(v)));
            })
       .def("get_fp16",
            [](TritonOpBuilder &self, float v) -> Value {
@@ -857,17 +857,17 @@ void init_triton_ir(py::module &&m) {
           [](TritonOpBuilder &self, Type type) -> Value {
             if (auto floatTy = dyn_cast<FloatType>(type))
               return self.create<arith::ConstantFloatOp>(
-                  APFloat(floatTy.getFloatSemantics(), 0), floatTy);
+                  floatTy, APFloat(floatTy.getFloatSemantics(), 0));
             else if (auto intTy = dyn_cast<IntegerType>(type))
-              return self.create<arith::ConstantIntOp>(0, intTy);
+              return self.create<arith::ConstantIntOp>(intTy, 0);
             else
               throw std::runtime_error("Not implemented");
           })
       .def("get_all_ones_value",
           [](TritonOpBuilder &self, Type type) -> Value {
             uint64_t val = 0xFFFFFFFFFFFFFFFF;
             if (auto intTy = dyn_cast<IntegerType>(type))
-              return self.create<arith::ConstantIntOp>(val, intTy);
+              return self.create<arith::ConstantIntOp>(intTy, val);
             else
               throw std::runtime_error("Not implemented");
           })
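
The float builders in ir.cc follow the same reordering: arith::ConstantFloatOp now takes the float type ahead of the APFloat payload. A small sketch under the same bumped-MLIR assumption; makeFloatConstant is a hypothetical helper:

#include <string>

#include "llvm/ADT/APFloat.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// Build a float constant with the new (type, value) argument order,
// rounding the literal through the target type's semantics as ir.cc does.
mlir::Value makeFloatConstant(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::FloatType fty, float v) {
  llvm::APFloat apv(fty.getFloatSemantics(), std::to_string(v));
  return b.create<mlir::arith::ConstantFloatOp>(loc, fty, apv);
}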

python/src/llvm.cc

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ void init_triton_llvm(py::module &&m) {
         ModuleAnalysisManager mam;

         if (arch.empty()) {
-          llvm::TargetLibraryInfoImpl TLII;
+          llvm::TargetLibraryInfoImpl TLII(mod->getTargetTriple());
           TLII.disableAllFunctions();
           fam.registerPass([TLII = std::move(TLII)] {
             return llvm::TargetLibraryAnalysis(TLII);
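
This one tracks an LLVM API change in the new revision: the TargetLibraryInfoImpl used on the arch-less path is now constructed from the module's target triple rather than default-constructed, so library-call recognition matches the target. A sketch of the same pattern, assuming an LLVM where Module::getTargetTriple() returns an llvm::Triple; makeBareTLII is an illustrative helper:

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Module.h"

// Build a TLII for the module's target, then opt out of all library-call
// recognition, mirroring what the pipeline above does.
llvm::TargetLibraryInfoImpl makeBareTLII(const llvm::Module &mod) {
  llvm::TargetLibraryInfoImpl tlii(mod.getTargetTriple());
  tlii.disableAllFunctions();
  return tlii;
}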

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
@@ -1628,7 +1628,7 @@ def kernel(X, Z):
     # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
     # atom.cas add loop on Ampere and prior
     if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
-        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
+        assert f"atom.{sem_str}.gpu.global.cas" in h.asm["ptx"]
         return

     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]

python/test/unit/language/test_line_info.py

Lines changed: 33 additions & 32 deletions
@@ -317,25 +317,26 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
     # CHECK: #loc = loc("{{.*}}":316:0)
     # CHECK-LABEL: tt.func public @kernel_basic(
     # CHECK-SAME: %src: !tt.ptr<f32> loc("src"(#loc)), %N: i32 loc("N"(#loc)))
-    # CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc1)
-    # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc1)
-    # CHECK: %pid = tt.get_program_id x : i32 loc(#loc14)
-    # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc15)
-    # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc16)
-    # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc17)
-    # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc17)
-    # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc18)
-    # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc18)
-    # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc19)
-    # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc19)
-    # CHECK: %x_plus_1 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc20)
-    # CHECK: %x_plus_1_4 = arith.addf %x_plus_1, %cst : tensor<16xf32> loc(#loc21)
-    # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_4, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
+    # CHECK: %x_plus_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc14)
+    # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc2)
+    # CHECK: %pid = tt.get_program_id x : i32 loc(#loc15)
+    # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc16)
+    # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc17)
+    # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc18)
+    # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc18)
+    # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc19)
+    # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc19)
+    # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc20)
+    # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc20)
+    # CHECK: %x_plus_1_4 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc21)
+    # CHECK: %x_plus_1_5 = arith.addf %x_plus_1_4, %x_plus_1 : tensor<16xf32> loc(#loc14)
+    # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_5, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
     # CHECK: tt.return loc(#loc11)
-    # CHECK: } loc(#loc)
+    # CHECK: } loc(#loc)
+    # CHECK: } loc(#loc)

-    # CHECK: #loc1 = loc(unknown)
-    # CHECK: #loc2 = loc({{.*}})
+    # CHECK: #loc1 = loc({{.*}})
+    # CHECK: #loc2 = loc(unknown)
     # CHECK: #loc3 = loc({{.*}})
     # CHECK: #loc4 = loc({{.*}})
     # CHECK: #loc5 = loc({{.*}})
@@ -345,13 +346,13 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
     # CHECK: #loc9 = loc({{.*}})
     # CHECK: #loc10 = loc({{.*}})
     # CHECK: #loc11 = loc({{.*}})
-    # CHECK: #loc14 = loc("pid"(#loc2))
-    # CHECK: #loc15 = loc("offset"(#loc3))
-    # CHECK: #loc16 = loc("offsets"(#loc4))
+    # CHECK: #loc14 = loc("x_plus_1"(#loc1))
+    # CHECK: #loc15 = loc("pid"(#loc3))
+    # CHECK: #loc16 = loc("offset"(#loc4))
     # CHECK: #loc17 = loc("offsets"(#loc5))
-    # CHECK: #loc18 = loc("load_src_store_dst"(#loc6))
-    # CHECK: #loc19 = loc("mask"(#loc7))
-    # CHECK: #loc20 = loc("x_plus_1"(#loc8))
+    # CHECK: #loc18 = loc("offsets"(#loc6))
+    # CHECK: #loc19 = loc("load_src_store_dst"(#loc7))
+    # CHECK: #loc20 = loc("mask"(#loc8))
     # CHECK: #loc21 = loc("x_plus_1"(#loc9))

     pid = tl.program_id(0)
@@ -459,20 +460,20 @@ def kernel_basic_while(N):
     # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
     arange = tl.arange(0, 16)
     ivar = 0
-    # CHECK: %ivar:2 = scf.while (%arange_0 = %arange, %ivar_1 = %c0_i32) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
-    # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_1, %N : i32
-    # CHECK: scf.condition(%[[COND]]) %arange_0, %ivar_1 : tensor<16xi32>, i32
+    # CHECK: %ivar_[[IV0:.+]]:2 = scf.while (%arange_[[AR0:.+]] = %arange, %ivar_[[IV1:.+]] = %ivar) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
+    # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_[[IV1]], %N : i32
+    # CHECK: scf.condition(%[[COND]]) %arange_[[AR0]], %ivar_[[IV1]] : tensor<16xi32>, i32
     while ivar < N:
-        # CHECK: ^bb0(%arange_0: tensor<16xi32> loc("arange"), %ivar_1: i32
+        # CHECK: ^bb0(%arange_[[AR0]]: tensor<16xi32> loc("arange"), %ivar_[[IV1]]: i32

-        # CHECK: %ivar_2 = arith.addi %ivar_1, %c1_i32 : i32
+        # CHECK: %ivar_[[IV2:.+]] = arith.addi %ivar_[[IV1]], %c1_i32 : i32
         ivar += 1
-        # CHECK: %arange_3 = tt.splat %ivar_2 : i32 -> tensor<16xi32>
-        # CHECK: %arange_4 = arith.muli %arange_0, %arange_3 : tensor<16xi32>
-        # CHECK: scf.yield %arange_4, %ivar_2 : tensor<16xi32>, i32
+        # CHECK: %arange_[[AR1:.+]] = tt.splat %ivar_[[IV2]] : i32 -> tensor<16xi32>
+        # CHECK: %arange_[[AR2:.+]] = arith.muli %arange_[[AR0]], %arange_[[AR1]] : tensor<16xi32>
+        # CHECK: scf.yield %arange_[[AR2]], %ivar_[[IV2]] : tensor<16xi32>, i32
         arange *= ivar

-    # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar#0 : tensor<16xi32>
+    # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar_[[IV0]]#0 : tensor<16xi32>
     tl.device_print("", arange)

     h = triton.compile(triton.compiler.ASTSource(fn=kernel_basic_while, signature={"N": "i32"}, constexprs={}))
