diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 22889e46f9..6aa84e6b7e 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
 LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
-  // TODO: For sure not the right way to do this
-  // but why is scf.if not initialized otherwise?
+  // If any operands are not yet ready, skip this operation for now.
   for (auto op : operands)
     if (op->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
+      return success();
   AxisInfo curr = visitors.apply(op, operands);
   if (curr.getRank() == 0) {
     setAllToEntryStates(results);
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   ProgramPoint *programPoint = getProgramPointAfter(op);
   auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
   auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
-  for (auto op_iter : {lbLattice, stepLattice})
-    if (op_iter->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
+  // If lb or step is not yet ready, skip this operation for now.
+  if (lbLattice->getValue().getRank() == 0 ||
+      stepLattice->getValue().getRank() == 0) {
+    return;
+  }
 
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
       initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
                                    &knownContiguity, &knownDivisibility,
                                    &knownConstancy);
-    } else if (isa<scf::ForOp, scf::IfOp, scf::WhileOp,
-                   gpu::WarpSpecializePartitionsOp>(op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
+    } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
+      // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
+      // "unknown" state: the maximum possible divisibility, contiguity, and
+      // constancy.
       knownDivisibility = DimVectorT(rank, kMaxDivisor);
       knownConstancy = DimVectorT(rank, kMaxDivisor);
       knownContiguity = DimVectorT(rank, kMaxDivisor);
     }
   } else if (Operation *op = value.getDefiningOp()) {
-    if (isa<scf::ForOp, scf::IfOp, scf::WhileOp>(op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
-      knownDivisibility = DimVectorT(rank, kMaxDivisor);
-      knownConstancy = DimVectorT(rank, kMaxDivisor);
-      knownContiguity = DimVectorT(rank, kMaxDivisor);
-    }
     // Other operations are conservatively initialized with the lowest possible
     // divisibility, contiguity, and constancy unless they have specified.
     AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
     auto axisInfo = analysis->getLatticeElement(value)->getValue();
+    // If we could not determine the AxisInfo for this value, assume the
+    // pessimistic state.
+    if (axisInfo.getRank() == 0)
+      axisInfo = AxisInfo::getPessimisticValueState(value);
     auto &valInfo = (*axisInfoMap)[value];
     valInfo = AxisInfo::join(axisInfo, valInfo);
   };
diff --git a/test/Analysis/intel/test-axis-info.mlir b/test/Analysis/intel/test-axis-info.mlir
index be60139bec..1a1665ffca 100644
--- a/test/Analysis/intel/test-axis-info.mlir
+++ b/test/Analysis/intel/test-axis-info.mlir
@@ -516,14 +516,14 @@ tt.func @for_if(%i1: i1, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
   %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>>
   %2 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %1) -> (tensor<128x64x!tt.ptr<f16>>) : i32 {
     // CHECK: scf.if
-    // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
+    // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
     %3 = scf.if %i1 -> (tensor<128x64x!tt.ptr<f16>>) {
       scf.yield %arg1 : tensor<128x64x!tt.ptr<f16>>
     } else {
       scf.yield %arg1 : tensor<128x64x!tt.ptr<f16>>
     }
     // CHECK: tt.addptr
-    // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 64], constant_value = <none>
+    // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
     %4 = tt.addptr %3, %cst : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
     // CHECK: scf.for
     // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
@@ -551,9 +551,9 @@ tt.func @for_if_for(%i1: i1, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %
   // CHECK: scf.for
   // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = <none>
   // CHECK: scf.if
-  // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = <none>
+  // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
   // CHECK: tt.addptr
-  // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 64], constant_value = <none>
+  // CHECK-SAME: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
   // CHECK: scf.for
   // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = <none>
   %3 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg2 = %1) -> (tensor<128x64x!tt.ptr<f16>>) : i32 {
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index 49821e969f..7f17dca4f9 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() {
   }
   tt.return
 }
+
+// -----
+
+// Verify that if an operation is statically determined to be dead, we fall back
+// to assigning it a pessimistic value, rather than skipping it entirely.
+tt.func @dead_op_pessimistic() {
+  %c5 = arith.constant dense<5> : tensor<4xi32>
+  %c7 = arith.constant dense<7> : tensor<4xi32>
+  %false = arith.constant false
+  scf.if %false {
+    // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
+    %add = arith.addi %c5, %c7 : tensor<4xi32>
+  }
+  tt.return
+}
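
For reviewers unfamiliar with the convention this change leans on: a rank-0 `AxisInfo` serves as the uninitialized (bottom) lattice state, which is why `getRank() == 0` is the "not yet ready" test in the hunks above. Below is a minimal standalone sketch of that convention; the `ToyAxisInfo` names are illustrative only, not Triton's actual API, and the join is simplified to divisibility alone.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch of the "rank 0 == uninitialized" lattice convention. All names
// here (ToyAxisInfo, getPessimistic) are hypothetical, for illustration.
struct ToyAxisInfo {
  std::vector<int64_t> divisibility; // per-dimension known divisor

  // Rank 0 doubles as the bottom/uninitialized state: a visitor that sees
  // a rank-0 operand has nothing to compute from yet, so it bails out.
  size_t getRank() const { return divisibility.size(); }

  // Lowest useful state: divisibility 1 in every dimension. This is the
  // role of the new fallback in ModuleAxisInfoAnalysis::initialize for
  // values the solver never visited (e.g. ops inside dead scf.if regions).
  static ToyAxisInfo getPessimistic(size_t rank) {
    return ToyAxisInfo{std::vector<int64_t>(rank, 1)};
  }

  // Join treats an uninitialized side as the identity, so an unvisited
  // value never poisons information already present in the map.
  static ToyAxisInfo join(const ToyAxisInfo &a, const ToyAxisInfo &b) {
    if (a.getRank() == 0)
      return b;
    if (b.getRank() == 0)
      return a;
    assert(a.getRank() == b.getRank());
    ToyAxisInfo out;
    for (size_t d = 0; d < a.getRank(); ++d)
      out.divisibility.push_back(std::gcd(a.divisibility[d], b.divisibility[d]));
    return out;
  }
};
```

With a pessimistic fallback available at map-construction time, visitors no longer need to force-initialize unready operands via `setToEntryState`; they can simply return and be revisited once the operands converge, and statically dead code (as in the new `@dead_op_pessimistic` test) ends up with the `[1]`/`[1]`/`[1]` state instead of being skipped.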