Commit 0a73d54

Merge commit '6294db5a12443a49d1f0604a8de08d2b4b921497'
2 parents: 8dc24ec + 6294db5

File tree

19 files changed: +192 / -148 lines

lib/Analysis/AxisInfo.cpp

Lines changed: 15 additions & 19 deletions
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
 LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
-  // TODO: For sure not the right way to do this
-  // but why is scf.if not initialized otherwise?
+  // If any operands are not yet ready, skip this operation for now.
   for (auto op : operands)
     if (op->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
+      return success();
   AxisInfo curr = visitors.apply(op, operands);
   if (curr.getRank() == 0) {
     setAllToEntryStates(results);
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   ProgramPoint *programPoint = getProgramPointAfter(op);
   auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
   auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
-  for (auto op_iter : {lbLattice, stepLattice})
-    if (op_iter->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
+  // If lb or step is not yet ready, skip this operation for now.
+  if (lbLattice->getValue().getRank() == 0 ||
+      stepLattice->getValue().getRank() == 0) {
+    return;
+  }

   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
       initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
                                    &knownContiguity, &knownDivisibility,
                                    &knownConstancy);
-    } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
-                   op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
+    } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
+      // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
+      // "unknown" state: the maximum possible divisibility, contiguity, and
+      // constancy.
       knownDivisibility = DimVectorT(rank, kMaxDivisor);
       knownConstancy = DimVectorT(rank, kMaxDivisor);
       knownContiguity = DimVectorT(rank, kMaxDivisor);
     }
   } else if (Operation *op = value.getDefiningOp()) {
-    if (isa<RegionBranchOpInterface>(op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
-      knownDivisibility = DimVectorT(rank, kMaxDivisor);
-      knownConstancy = DimVectorT(rank, kMaxDivisor);
-      knownContiguity = DimVectorT(rank, kMaxDivisor);
-    }
     // Other operations are conservatively initialized with the lowest possible
     // divisibility, contiguity, and constancy unless they have specified.
     AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
     auto axisInfo = analysis->getLatticeElement(value)->getValue();
+    // If we could not determine the AxisInfo for this value, assume the
+    // pessimistic state.
+    if (axisInfo.getRank() == 0)
+      axisInfo = AxisInfo::getPessimisticValueState(value);
     AxisInfo curAxisInfo;
     if (axisInfoMap->count(value)) {
       curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
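
Taken together, the two halves of this change work as follows: visitOperation now defers whenever an operand lattice is still uninitialized (rank 0) instead of forcing it to the entry state, and ModuleAxisInfoAnalysis::initialize later fills in the pessimistic state for any value the solver never resolved (for example, dead code). Below is a minimal standalone C++ sketch of that defer-then-fall-back pattern, using hypothetical toy types rather than the real MLIR dataflow classes.

// Illustrative sketch only: toy stand-ins for the dataflow lattice used by
// AxisInfoAnalysis; not the real MLIR API.
#include <iostream>
#include <optional>
#include <vector>

struct ToyInfo {
  int rank = 0;          // rank == 0 means "not computed yet"
  int divisibility = 1;
};

// Visit an op: if any operand is not ready, defer (return no result) instead
// of pinning it to the pessimistic entry state, so a later revisit can refine it.
std::optional<ToyInfo> visitOp(const std::vector<ToyInfo> &operands) {
  for (const ToyInfo &in : operands)
    if (in.rank == 0)
      return std::nullopt;            // skip this op for now
  if (operands.empty())
    return ToyInfo{1, 1};
  return ToyInfo{1, operands.front().divisibility};
}

int main() {
  std::vector<ToyInfo> operands = {{0, 1}}; // operand never resolved (dead code)
  std::optional<ToyInfo> result = visitOp(operands);
  // After the fixpoint, anything still unresolved falls back to pessimistic.
  ToyInfo resolved = result.value_or(ToyInfo{1, 1});
  std::cout << "divisibility = " << resolved.divisibility << "\n";
}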

lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc,
                          const LLVMTypeConverter *typeConverter,
                          ttg::MemDescType memDescTy, Value sharedMemStruct) {
   TritonLLVMOpBuilder b(loc, rewriter);
-  if (isa<ttng::TensorMemoryEncodingAttr>(memDescTy.getEncoding())) {
+  if (isa<ttng::TensorMemorySpaceAttr>(memDescTy.getMemorySpace())) {
     return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct);
   }
   assert(isa<ttg::SharedEncodingTrait>(memDescTy.getEncoding()) &&

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -2505,9 +2505,9 @@ LogicalResult DotOperandEncodingAttr::verify(
     return emitError()
            << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 "
               "(including packed cases for `scaled_dot`)";
-  if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth))
+  if (parentAttr.getVersion() == 3 && kWidth == 0)
     return emitError()
-           << "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3";
+           << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 ";
   return success();
 }

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 25 additions & 16 deletions
@@ -127,7 +127,7 @@ class LayoutRematerialization {
   }

   void cleanup();
-  void backwardRematerialization();
+  bool backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
   // TODO: Merge the three hoistConvert*(); functions as they are duplicate code
   void hoistConvertDotOperand();
@@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
   return success();
 }

-void LayoutRematerialization::backwardRematerialization() {
+bool LayoutRematerialization::backwardRematerialization() {
+  bool changed = false;
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
@@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
       // backward slices.
       addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
                     convertOp.getResult());
+    } else {
+      changed = true;
     }
   }
+  return changed;
 }

 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
@@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
   rewriteSlice(slice, layout, convertOp, mapping);
 }

-void backwardRematerialization(ModuleOp module) {
-  module.walk([](FuncOp funcOp) {
+bool backwardRematerialization(ModuleOp module) {
+  bool changed = false;
+  module.walk([&](FuncOp funcOp) {
     LayoutRematerialization layoutRemat(funcOp);
-    layoutRemat.backwardRematerialization();
+    changed |= layoutRemat.backwardRematerialization();
     layoutRemat.cleanup();
   });
+  return changed;
 }

 void hoistConvert(ModuleOp module) {
@@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass

     cleanupConvertOps();

-    // 2. For remaining convert ops, try to rematerialize the slice of producer
-    // operation to avoid having to convert.
-    backwardRematerialization(m);
-    LLVM_DEBUG({
-      DBGS() << "Module after backward remat:\n";
-      m.dump();
-    });
-
-    // Cleanup dummy converts created during backward remat.
-    cleanupConvertOps();
-
+    bool changed = false;
+    do {
+      changed = false;
+      // 2. For remaining convert ops, try to rematerialize the slice of
+      // producer operation to avoid having to convert.
+      changed = backwardRematerialization(m);
+      LLVM_DEBUG({
+        DBGS() << "Module after backward remat:\n";
+        m.dump();
+      });
+
+      // Cleanup dummy converts created during backward remat.
+      cleanupConvertOps();
+    } while (changed);
     // 3. For remaining converts, try to hoist them above cast generating larger
     // size types in order to reduce the cost of the convert op.
     hoistConvert(m);
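
The driver change above turns a single backward-rematerialization step into a loop that repeats until no convert op is removed. A minimal sketch of that fixed-point shape, with hypothetical helper names standing in for the pass's functions:

// Illustrative fixed-point driver: re-run a rewrite until it reports no
// change. The helpers here are hypothetical stand-ins, not the pass's API.
#include <iostream>

static int rewritesLeft = 3;

bool runBackwardRemat() {   // returns true if it rewrote something
  if (rewritesLeft == 0)
    return false;
  --rewritesLeft;
  return true;
}

void cleanupConverts() { /* drop temporary converts created by the rewrite */ }

int main() {
  bool changed = false;
  do {
    changed = runBackwardRemat(); // e.g. backwardRematerialization(m)
    cleanupConverts();            // e.g. cleanupConvertOps()
  } while (changed);
  std::cout << "reached fixed point\n";
}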

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 6 additions & 1 deletion
@@ -3,11 +3,12 @@
 from dataclasses import dataclass

 import triton
+from triton_kernels import target_info
 from triton_kernels.target_info import get_cdna_version
 from triton_kernels.tensor import FP4
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
-from triton_kernels.tensor import bitwidth
+from triton_kernels.tensor import bitwidth, get_layout


 @dataclass
@@ -297,8 +298,12 @@ def make_default_opt_flags_nvidia(
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
     tiles_per_sm = grid_size_tma / n_sms
     supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
+    requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp()
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
+    elif requires_persistent:
+        assert supports_persistent, "persistent kernel required but not supported"
+        is_persistent = True
     else:
         has_simple_epilogue = precision_config.max_num_imprecise_acc is None
         is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
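
The new branch sits between the explicit constraint and the existing heuristic: if either operand carries a scale layout and the target has native mxfp support, the kernel must be persistent. A small sketch of that precedence, with hypothetical names standing in for values the real code reads from constraints, precision_config, and target_info:

// Illustrative sketch of the decision order added here: an explicit
// constraint wins, then the "required" case, then the existing heuristic.
#include <cassert>
#include <iostream>
#include <optional>

bool decideIsPersistent(std::optional<bool> constraint, bool supportsPersistent,
                        bool requiresPersistent, bool heuristicPrefersIt) {
  if (constraint)
    return *constraint;        // caller pinned is_persistent explicitly
  if (requiresPersistent) {
    assert(supportsPersistent && "persistent kernel required but not supported");
    return true;
  }
  return supportsPersistent && heuristicPrefersIt;
}

int main() {
  // Scale layouts present plus native mxfp: persistent even if the heuristic says no.
  std::cout << decideIsPersistent(std::nullopt, /*supportsPersistent=*/true,
                                  /*requiresPersistent=*/true,
                                  /*heuristicPrefersIt=*/false)
            << "\n";
}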

test/Analysis/test-alignment.mlir

Lines changed: 15 additions & 0 deletions
@@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() {
   }
   tt.return
 }
+
+// -----
+
+// Verify that if an operation is statically determined to be dead, we fall back
+// to assigning it a pessimistic value, rather than skipping it entirely.
+tt.func @dead_op_pessimistic() {
+  %c5 = arith.constant dense<5> : tensor<4xi32>
+  %c7 = arith.constant dense<7> : tensor<4xi32>
+  %false = arith.constant false
+  scf.if %false {
+    // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
+    %add = arith.addi %c5, %c7 : tensor<4xi32>
+  }
+  tt.return
+}

test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir

Lines changed: 0 additions & 3 deletions
@@ -81,7 +81,6 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 // CHECK-LABEL: async_load_multicast_to_half_ctas
 tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                   %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
-  // CHECK: llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
   // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
@@ -104,7 +103,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
 tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                              %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
   // Skip the first cluster id because it's emitted for address calculation
-  // CHECK: llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
   // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
@@ -146,7 +144,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
 tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr<f32>, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                    %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
   // Skip the first cluster id because it's emitted for address calculation
-  // CHECK: llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
   // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
   // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]

test/Conversion/amd/math-denorm-handling.mlir

Lines changed: 14 additions & 16 deletions
@@ -64,22 +64,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) {
-    // LLVM_FTZ-LABEL: test_sqrt_rn_f32
-    // LLVM_FTZ: llvm.amdgcn.rsq.f32
-    // LLVM_FTZ: llvm.fmul
-    // LLVM_FTZ: llvm.fmul
-    // LLVM_FTZ: llvm.fneg
-    // LLVM_FTZ: llvm.intr.fma
-    // LLVM_FTZ-NEXT: llvm.intr.fma
-    // LLVM_FTZ-NEXT: llvm.intr.fma
-    // LLVM_FTZ-NEXT: llvm.fneg
-    // LLVM_FTZ-NEXT: llvm.intr.fma
-    // LLVM_FTZ-NEXT: llvm.intr.fma
-    // LLVM_FTZ-NEXT: llvm.intr.is.fpclass
-    // LLVM_FTZ-NEXT: llvm.select
-    //
-    // LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32
-    // LLVM_NO_FTZ: llvm.intr.sqrt
+    // COMMON-LABEL: test_sqrt_rn_f32
+    // COMMON: llvm.intr.sqrt
     %0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked>
     tt.return
   }
@@ -96,3 +82,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_divf_rn_f32(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) {
+    // COMMON-LABEL: test_divf_rn_f32
+    // COMMON: llvm.fdiv
+    %0 = tt.precise_divf %arg0, %arg1 : tensor<64xf32, #blocked>
+    tt.return
+  }
+}

test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir

Lines changed: 26 additions & 0 deletions
@@ -200,3 +200,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape=[16, 16, 128]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: wmma_scaled_dot_fp8_chained
+  tt.func @wmma_scaled_dot_fp8_chained(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg2: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %out0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %scale0 = arith.constant dense<127> : tensor<128x4xi8, #linear>
+    %scale1 = arith.constant dense<127> : tensor<128x4xi8, #linear1>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm0 = tt.dot_scaled %arg0 scale %scale0, %arg2 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    // CHECK-NOT: rocdl.ds_swizzle
+    // CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
+    %op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %op1 = tt.fp_to_fp %op0, rounding = rtne : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm1 = tt.dot_scaled %op1 scale %scale0, %arg3 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    %ptr0 = tt.splat %out0 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.store %ptr0, %mm1 : tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}
