34 changes: 15 additions & 19 deletions lib/Analysis/AxisInfo.cpp
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
 LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
-  // TODO: For sure not the right way to do this
-  // but why is scf.if not initialized otherwise?
+  // If any operands are not yet ready, skip this operation for now.
   for (auto op : operands)
     if (op->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
+      return success();
   AxisInfo curr = visitors.apply(op, operands);
   if (curr.getRank() == 0) {
     setAllToEntryStates(results);
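
Why return early instead of forcing unready operands to the entry state, as the deleted lines did: in a worklist-driven dataflow solver, a skipped operation is simply revisited once its operands have been computed, whereas clobbering an operand's lattice with the entry state discards information the solver may already be propagating. A minimal standalone sketch of the revisit pattern (hypothetical types; `Info`, `Lattice`, and `visit` are illustrative stand-ins, not Triton's classes):

```cpp
#include <algorithm>
#include <cstdio>
#include <optional>
#include <vector>

struct Info {
  int divisibility;
};
using Lattice = std::optional<Info>; // nullopt plays the role of rank == 0

// One transfer step: the result's divisibility is the min over all operands.
// If any operand is unready, bail out so the solver can retry later.
Lattice visit(const std::vector<Lattice> &operands) {
  int div = 1 << 30;
  for (const Lattice &o : operands) {
    if (!o)
      return std::nullopt; // operand not ready: skip this op for now
    div = std::min(div, o->divisibility);
  }
  return Info{div};
}

int main() {
  std::vector<Lattice> operands = {Info{16}, std::nullopt};
  Lattice r = visit(operands);
  std::printf("first pass: %s\n", r ? "computed" : "skipped");
  operands[1] = Info{8}; // the solver later fills in the missing operand
  r = visit(operands);   // the revisit now succeeds
  if (r)
    std::printf("second pass: divisibility = %d\n", r->divisibility);
}
```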
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   ProgramPoint *programPoint = getProgramPointAfter(op);
   auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
   auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
-  for (auto op_iter : {lbLattice, stepLattice})
-    if (op_iter->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
+  // If lb or step is not yet ready, skip this operation for now.
+  if (lbLattice->getValue().getRank() == 0 ||
+      stepLattice->getValue().getRank() == 0) {
+    return;
+  }
 
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
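
Once `lb` and `step` are known, the induction variable is `lb + k * step` for some iteration count `k`, so any divisor common to both also divides the induction variable; that is the value ultimately folded into the `DimVectorT`s initialized above. A sketch of the rule (an illustrative helper, not Triton's actual code):

```cpp
#include <cstdint>
#include <numeric>

// iv = lb + k * step, so gcd(div(lb), div(step)) divides iv for every k.
int64_t inductionVarDivisibility(int64_t lbDivisibility,
                                 int64_t stepDivisibility) {
  return std::gcd(lbDivisibility, stepDivisibility);
}
```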
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
       initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
                                    &knownContiguity, &knownDivisibility,
                                    &knownConstancy);
-    } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
-                   op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
+    } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
+      // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
+      // "unknown" state: the maximum possible divisibility, contiguity, and
+      // constancy.
       knownDivisibility = DimVectorT(rank, kMaxDivisor);
       knownConstancy = DimVectorT(rank, kMaxDivisor);
       knownContiguity = DimVectorT(rank, kMaxDivisor);
     }
   } else if (Operation *op = value.getDefiningOp()) {
-    if (isa<RegionBranchOpInterface>(op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
-      knownDivisibility = DimVectorT(rank, kMaxDivisor);
-      knownConstancy = DimVectorT(rank, kMaxDivisor);
-      knownContiguity = DimVectorT(rank, kMaxDivisor);
-    }
     // Other operations are conservatively initialized with the lowest possible
     // divisibility, contiguity, and constancy unless otherwise specified.
     AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
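The hunk above narrows the "unknown" (fully optimistic) initialization to `gpu::WarpSpecializePartitionsOp` arguments only; values tied to `RegionBranchOpInterface` ops (`scf.for`, `scf.if`, `scf.while`) now flow through the normal transfer functions and, failing that, the pessimistic fallback added further down. For a rank-2 value the two extreme states look roughly like this (a sketch; `kMaxDivisorSketch` is a stand-in, not Triton's actual `kMaxDivisor` value):

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kMaxDivisorSketch = int64_t(1) << 30; // stand-in constant

struct AxisInfoSketch {
  std::vector<int64_t> contiguity, divisibility, constancy;
};

// "Unknown"/optimistic state: every axis gets the maximum possible values.
AxisInfoSketch unknownState(int rank) {
  std::vector<int64_t> maxed(rank, kMaxDivisorSketch);
  return {maxed, maxed, maxed};
}

// Pessimistic state: every axis gets the minimum safe values (all ones).
AxisInfoSketch pessimisticState(int rank) {
  std::vector<int64_t> ones(rank, 1);
  return {ones, ones, ones};
}
```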
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
     auto axisInfo = analysis->getLatticeElement(value)->getValue();
+    // If we could not determine the AxisInfo for this value, assume the
+    // pessimistic state.
+    if (axisInfo.getRank() == 0)
+      axisInfo = AxisInfo::getPessimisticValueState(value);
     auto &valInfo = (*axisInfoMap)[value];
     valInfo = AxisInfo::join(axisInfo, valInfo);
   };
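
Because `visitOperation` now skips unready operations instead of forcing them to the entry state, a value in dead code (e.g. inside an `scf.if false` region, as the new test below exercises) may never be visited at all and keeps the default rank-0 lattice. The hunk above catches that case when results are collected; reduced to its essentials (simplified types, not the MLIR API):

```cpp
struct Info {
  int rank = 0; // rank == 0 means "never computed by the solver"
};

Info pessimistic() { return Info{1}; } // the safest possible assumption

// Collect a solver result, substituting the pessimistic state for values
// the fixpoint iteration never reached.
Info collect(const Info &computed) {
  return computed.rank == 0 ? pessimistic() : computed;
}
```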
8 changes: 4 additions & 4 deletions test/Analysis/intel/test-axis-info.mlir
@@ -516,14 +516,14 @@ tt.func @for_if(%i1: i1, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
   %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>>
   %2 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %1) -> (tensor<128x64x!tt.ptr<f16>>): i32 {
     // CHECK: scf.if
-    // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
+    // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
     %3 = scf.if %i1 -> (tensor<128x64x!tt.ptr<f16>>) {
       scf.yield %arg1 : tensor<128x64x!tt.ptr<f16>>
     } else {
       scf.yield %arg1 : tensor<128x64x!tt.ptr<f16>>
     }
     // CHECK: tt.addptr
-    // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 64], constant_value = <none>
+    // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
     %4 = tt.addptr %3, %cst : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
     // CHECK: scf.for
     // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
@@ -551,9 +551,9 @@ tt.func @for_if_for(%i1: i1, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %
   // CHECK: scf.for
   // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = <none>
   // CHECK: scf.if
-  // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = <none>
+  // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
   // CHECK: tt.addptr
-  // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 64], constant_value = <none>
+  // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
   // CHECK: scf.for
   // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
   %3 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg2 = %1) -> (tensor<128x64x!tt.ptr<f16>>) : i32 {
15 changes: 15 additions & 0 deletions test/Analysis/test-alignment.mlir
@@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() {
   }
   tt.return
 }
+
+// -----
+
+// Verify that if an operation is statically determined to be dead, we fall back
+// to assigning it a pessimistic value, rather than skipping it entirely.
+tt.func @dead_op_pessimistic() {
+  %c5 = arith.constant dense<5> : tensor<4xi32>
+  %c7 = arith.constant dense<7> : tensor<4xi32>
+  %false = arith.constant false
+  scf.if %false {
+    // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
+    %add = arith.addi %c5, %c7 : tensor<4xi32>
+  }
+  tt.return
+}