Skip to content

Commit 1c61f0d

Browse files
ulysseBcopybara-github
authored andcommitted
Leaner iteration space analysis interface.
Improves the interface of IterationSpaceAnalysis to return a struct of `{loop_names, mapping}` instead of the raw loop nest. This makes the analysis easier to use on the caller side. PiperOrigin-RevId: 361103472
1 parent 89a623b commit 1c61f0d

File tree

5 files changed

+154
-123
lines changed

5 files changed

+154
-123
lines changed

loop_nest.cc

Lines changed: 108 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,50 @@
1919

2020
namespace sair {
2121

22+
IterationSpace::IterationSpace(llvm::SmallVector<mlir::StringAttr> loop_names,
23+
MappingAttr domain_to_loops)
24+
: loop_names_(std::move(loop_names)), domain_to_loops_(domain_to_loops) {
25+
assert(loop_names_.size() == domain_to_loops_.size());
26+
}
27+
2228
// Infers the iteration space for the current operation from iteration space of
2329
// the given operand. Trims inner loops so than only loops iterating on
2430
// dimensions mapped by the mapping remain. The resulting loop nest may
2531
// not cover all dimensions of the current operation.
26-
static mlir::ArrayAttr InferIterationSpace(
27-
mlir::ArrayAttr operand_iteration_space, ValueOperand &operand) {
28-
mlir::MLIRContext *context = operand_iteration_space.getContext();
32+
static IterationSpace InferIterationSpace(
33+
const IterationSpace &operand_iteration_space, ValueOperand &operand) {
2934
MappingAttr mapping = operand.Mapping();
3035

31-
llvm::SmallVector<mlir::Attribute> iteration_space;
32-
for (mlir::Attribute attr : operand_iteration_space.getValue()) {
33-
LoopAttr loop = attr.cast<LoopAttr>();
34-
if (loop.iter().MinDomainSize() > mapping.size()) break;
35-
36-
MappingExpr new_iter = loop.iter().SubstituteDims(mapping.Dimensions());
37-
LoopAttr new_loop = LoopAttr::get(loop.name(), new_iter, context);
38-
iteration_space.push_back(new_loop);
36+
llvm::SmallVector<mlir::StringAttr> loop_names;
37+
for (auto [name, iter] :
38+
llvm::zip(operand_iteration_space.loop_names(),
39+
operand_iteration_space.domain_to_loops())) {
40+
if (iter.MinDomainSize() > mapping.size()) break;
41+
loop_names.push_back(name);
3942
}
43+
44+
MappingAttr domain_to_loops = mapping.Compose(
45+
operand_iteration_space.domain_to_loops().Resize(loop_names.size()));
46+
4047
// If the iteration space is infered from loop-carried dimensions, trim inner
4148
// parallel dimensions as inner parallel dimension open at the end of the
4249
// previous iteration along loop-carried dimension may not be open at the
4350
// beginning of the current iteration.
4451
if (operand.AllowUseBeforeDef()) {
4552
llvm::SmallBitVector carrying_dims = operand.CarryingDims();
46-
while (!iteration_space.empty()) {
47-
LoopAttr loop = iteration_space.back().cast<LoopAttr>();
48-
int domain_size = mapping.UseDomainSize();
49-
if (loop.iter().DependencyMask(domain_size).anyCommon(carrying_dims)) {
53+
int domain_size = mapping.UseDomainSize();
54+
int new_size = loop_names.size();
55+
for (; new_size > 0; --new_size) {
56+
MappingExpr expr = domain_to_loops.Dimension(new_size - 1);
57+
if (expr.DependencyMask(domain_size).anyCommon(carrying_dims)) {
5058
break;
5159
}
52-
iteration_space.pop_back();
5360
}
61+
loop_names.resize(new_size);
62+
domain_to_loops = domain_to_loops.Resize(new_size);
5463
}
55-
return mlir::ArrayAttr::get(context, iteration_space);
64+
65+
return IterationSpace(std::move(loop_names), domain_to_loops);
5666
}
5767

5868
IterationSpaceAnalysis::IterationSpaceAnalysis(SairProgramOp program_op) {
@@ -62,39 +72,60 @@ IterationSpaceAnalysis::IterationSpaceAnalysis(SairProgramOp program_op) {
6272
}
6373
}
6474

65-
llvm::ArrayRef<mlir::Attribute> IterationSpaceAnalysis::IterationSpace(
66-
SairOp op) const {
67-
return iteration_space_.find(op.getOperation())->second.getValue();
75+
const IterationSpace &IterationSpaceAnalysis::Get(SairOp op) const {
76+
return iteration_space_.find(op.getOperation())->second;
6877
}
6978

70-
llvm::ArrayRef<mlir::Attribute> IterationSpaceAnalysis::IterationSpace(
71-
mlir::Value value) const {
72-
return IterationSpace(value.getDefiningOp());
79+
const IterationSpace &IterationSpaceAnalysis::Get(mlir::Value value) const {
80+
return Get(value.getDefiningOp());
7381
}
7482

75-
mlir::ArrayAttr IterationSpaceAnalysis::ComputeIterationSpace(
83+
const IterationSpace &IterationSpaceAnalysis::ComputeIterationSpace(
7684
mlir::Operation *operation) {
7785
if (auto it = iteration_space_.find(operation);
7886
it != iteration_space_.end()) {
7987
return it->second;
8088
}
8189

8290
mlir::MLIRContext *context = operation->getContext();
83-
mlir::ArrayAttr iteration_space = mlir::ArrayAttr::get(context, {});
91+
SairOp sair_op = cast<SairOp>(operation);
92+
int domain_size = sair_op.domain().size();
93+
94+
// Handle ComputeOp case.
8495
if (auto compute_op = dyn_cast<ComputeOp>(operation)) {
85-
iteration_space = compute_op.loop_nest().getValueOr(iteration_space);
86-
} else if (auto infer_iteration_space =
87-
dyn_cast<InferIterationSpaceOp>(operation)) {
88-
// Temporarily set the loop nest to nullptr to avoid infinite recursion.
89-
iteration_space_[operation] = iteration_space;
90-
int operand_pos = infer_iteration_space.infer_iteration_space_operand();
91-
ValueOperand operand = cast<SairOp>(operation).ValueOperands()[operand_pos];
92-
mlir::Operation *defining_op = operand.value().getDefiningOp();
93-
mlir::ArrayAttr parent_iteration_space = ComputeIterationSpace(defining_op);
94-
iteration_space = InferIterationSpace(parent_iteration_space, operand);
95-
}
96-
iteration_space_[operation] = iteration_space;
97-
return iteration_space;
96+
int num_loops = compute_op.LoopNestLoops().size();
97+
llvm::SmallVector<MappingExpr> exprs;
98+
exprs.reserve(num_loops);
99+
llvm::SmallVector<mlir::StringAttr> loop_names;
100+
loop_names.reserve(num_loops);
101+
102+
for (mlir::Attribute attr : compute_op.LoopNestLoops()) {
103+
LoopAttr loop = attr.cast<LoopAttr>();
104+
loop_names.push_back(loop.name());
105+
exprs.push_back(loop.iter());
106+
}
107+
auto mapping = MappingAttr::get(context, domain_size, exprs);
108+
return iteration_space_.try_emplace(operation, loop_names, mapping)
109+
.first->second;
110+
}
111+
112+
// Temporarily set an empty iteration space to avoid infinite recursion.
113+
auto empty_mapping = MappingAttr::get(context, domain_size, {});
114+
llvm::SmallVector<mlir::StringAttr> empty_names;
115+
auto it =
116+
iteration_space_.try_emplace(operation, empty_names, empty_mapping).first;
117+
118+
auto infer_iteration_space = dyn_cast<InferIterationSpaceOp>(operation);
119+
if (infer_iteration_space == nullptr) return it->second;
120+
121+
int operand_pos = infer_iteration_space.infer_iteration_space_operand();
122+
ValueOperand operand = sair_op.ValueOperands()[operand_pos];
123+
mlir::Operation *defining_op = operand.value().getDefiningOp();
124+
const IterationSpace &parent_iteration_space =
125+
ComputeIterationSpace(defining_op);
126+
it = iteration_space_.find(operation);
127+
it->second = InferIterationSpace(parent_iteration_space, operand);
128+
return it->second;
98129
}
99130

100131
// Analysis that keeps track of dependencies between loops.
@@ -169,17 +200,15 @@ class LoopNestConstraintsAnalysis {
169200
}
170201
}
171202

172-
mlir::ArrayRef<mlir::Attribute> iteration_space =
173-
iteration_spaces.IterationSpace(operation);
203+
const IterationSpace &iteration_space = iteration_spaces.Get(operation);
174204
llvm::SmallBitVector closed_dims = op.ResultsDimDependencies();
175205
bool closed_dims_seen = false;
176-
for (mlir::Attribute attr : iteration_space) {
177-
LoopAttr loop = attr.cast<LoopAttr>();
178-
constraints.open_loops.insert(loop.name());
179-
llvm::SmallBitVector iter_dims =
180-
loop.iter().DependencyMask(op.domain().size());
206+
for (int i = 0, e = iteration_space.size(); i < e; ++i) {
207+
constraints.open_loops.insert(iteration_space.loop_names()[i]);
208+
MappingExpr expr = iteration_space.domain_to_loops().Dimension(i);
209+
llvm::SmallBitVector iter_dims = expr.DependencyMask(op.domain().size());
181210
if (iter_dims.anyCommon(closed_dims)) {
182-
constraints.closed_loops.insert(loop.name());
211+
constraints.closed_loops.insert(iteration_space.loop_names()[i]);
183212
closed_dims_seen = true;
184213
}
185214
if (closed_dims_seen) {
@@ -342,45 +371,46 @@ static mlir::LogicalResult VerifyLoopsOpen(
342371
// * `carrying_dims`: if `dependency` is a loop-carried operand, lists
343372
// dimensions carrying the value of `dependency` across iterations.
344373
static mlir::LogicalResult VerifyDependency(
345-
SairOp op, llvm::ArrayRef<mlir::Attribute> op_loop_nest,
346-
ValueAccess dependency, const llvm::SmallBitVector &dim_dependencies,
374+
SairOp op, const IterationSpace &op_loop_nest, ValueAccess dependency,
375+
const llvm::SmallBitVector &dim_dependencies,
347376
const llvm::SmallBitVector &carrying_dims,
348377
const IterationSpaceAnalysis &iteration_space_analysis,
349378
const LoopNestConstraintsAnalysis &loop_constraints_analysis) {
350-
mlir::ArrayRef<mlir::Attribute> dep_loop_nest =
351-
iteration_space_analysis.IterationSpace(dependency.value);
379+
const IterationSpace &dep_loop_nest =
380+
iteration_space_analysis.Get(dependency.value);
352381

353382
// Verify dependencies with the operand loop nest.
354-
for (auto [op_attr, dep_attr] : llvm::zip(op_loop_nest, dep_loop_nest)) {
355-
LoopAttr op_loop = op_attr.cast<LoopAttr>();
356-
LoopAttr dep_loop = dep_attr.cast<LoopAttr>();
357-
if (op_loop.name() != dep_loop.name()) break;
383+
int min_size = std::min(op_loop_nest.size(), dep_loop_nest.size());
384+
for (int i = 0; i < min_size; ++i) {
385+
if (op_loop_nest.loop_names()[i] != dep_loop_nest.loop_names()[i]) break;
358386
// Ensure that we can unify the iterator of both loops if they are fused.
359387
MappingExpr expected_expr =
360-
dep_loop.iter().SubstituteDims(dependency.mapping.Dimensions());
361-
if (expected_expr.Unify(op_loop.iter()) != nullptr) continue;
362-
return (op.emitError() << "loop " << op_loop.name()
388+
dep_loop_nest.domain_to_loops().Dimension(i).SubstituteDims(
389+
dependency.mapping.Dimensions());
390+
MappingExpr given_expr = op_loop_nest.domain_to_loops().Dimension(i);
391+
if (expected_expr.Unify(given_expr) != nullptr) continue;
392+
return (op.emitError() << "loop " << op_loop_nest.loop_names()[i]
363393
<< " violates a data dependency")
364394
.attachNote(dependency.value.getLoc())
365395
<< "dependency from this operation";
366396
}
367397

368398
const LoopNestConstraintsAnalysis::Constraints &constraints =
369399
loop_constraints_analysis.GetConstraints(dependency.value);
370-
for (mlir::Attribute attr : op_loop_nest) {
371-
LoopAttr loop = attr.cast<LoopAttr>();
372-
if (constraints.closed_loops.contains(loop.name())) {
373-
return op.emitError() << "loop " << loop.name()
374-
<< " must be closed before this operation";
400+
for (int i = 0, e = op_loop_nest.size(); i < e; ++i) {
401+
mlir::StringAttr name = op_loop_nest.loop_names()[i];
402+
if (constraints.closed_loops.contains(name)) {
403+
return op.emitError()
404+
<< "loop " << name << " must be closed before this operation";
375405
}
376406

377-
if (!constraints.open_loops.contains(loop.name())) continue;
378-
llvm::SmallBitVector iter_dims =
379-
loop.iter().DependencyMask(op.domain().size());
407+
if (!constraints.open_loops.contains(name)) continue;
408+
MappingExpr expr = op_loop_nest.domain_to_loops().Dimension(i);
409+
llvm::SmallBitVector iter_dims = expr.DependencyMask(op.domain().size());
380410
if (!dim_dependencies.anyCommon(iter_dims)) continue;
381411

382412
return (dependency.value.getDefiningOp()->emitError()
383-
<< "operation cannot be nested in loop " << loop.name())
413+
<< "operation cannot be nested in loop " << name)
384414
.attachNote(op.getLoc())
385415
<< "because of this operation";
386416
}
@@ -406,8 +436,7 @@ static mlir::LogicalResult VerifyDependency(
406436
static mlir::LogicalResult VerifyDependencies(
407437
SairOp op, IterationSpaceAnalysis &iteration_space_analysis,
408438
LoopNestConstraintsAnalysis &loop_constaints_analysis) {
409-
llvm::ArrayRef<mlir::Attribute> loop_nest =
410-
iteration_space_analysis.IterationSpace(op);
439+
const IterationSpace &loop_nest = iteration_space_analysis.Get(op);
411440

412441
int domain_size = op.domain().size();
413442
for (int i = 0; i < domain_size; ++i) {
@@ -459,14 +488,13 @@ static mlir::LogicalResult VerifyLoopRanges(
459488

460489
// Ensure that each loop only iterate along a single sub-domain.
461490
static mlir::LogicalResult VerifySubDomains(
462-
SairOp op, llvm::ArrayRef<mlir::Attribute> iteration_space) {
491+
SairOp op, const IterationSpace &iteration_space) {
463492
llvm::SmallVector<int> sub_domains = op.SubDomains();
464493
assert(!sub_domains.empty() || iteration_space.empty());
465494

466-
for (mlir::Attribute attr : iteration_space) {
467-
LoopAttr loop = attr.cast<LoopAttr>();
468-
llvm::SmallBitVector dimensions =
469-
loop.iter().DependencyMask(op.domain().size());
495+
for (int i = 0, e = iteration_space.size(); i < e; ++i) {
496+
MappingExpr expr = iteration_space.domain_to_loops().Dimension(i);
497+
llvm::SmallBitVector dimensions = expr.DependencyMask(op.domain().size());
470498
if (!dimensions.any()) continue;
471499

472500
// Compute the sub-domain the loop belongs to. If the iterator is not fully
@@ -476,7 +504,7 @@ static mlir::LogicalResult VerifySubDomains(
476504
int sub_domain = 0;
477505
int min_dim_index = 0;
478506
int max_dim_index = sub_domains[0];
479-
if (loop.iter().IsFullySpecified()) {
507+
if (expr.IsFullySpecified()) {
480508
int first = dimensions.find_first();
481509
while (first >= max_dim_index) {
482510
min_dim_index = max_dim_index;
@@ -488,8 +516,8 @@ static mlir::LogicalResult VerifySubDomains(
488516
// sub-domain.
489517
if (dimensions.find_first() < min_dim_index ||
490518
dimensions.find_last() >= max_dim_index) {
491-
return op.emitError()
492-
<< "loop " << loop.name() << " crosses sub-domains boundaries";
519+
return op.emitError() << "loop " << iteration_space.loop_names()[i]
520+
<< " crosses sub-domains boundaries";
493521
}
494522
}
495523
return mlir::success();
@@ -535,8 +563,7 @@ mlir::LogicalResult VerifyLoopNests(SairProgramOp program) {
535563

536564
// Verify dependencies.
537565
result = program.walk([&](SairOp op) -> mlir::WalkResult {
538-
if (mlir::failed(VerifySubDomains(
539-
op, iteration_space_analysis.IterationSpace(op)))) {
566+
if (mlir::failed(VerifySubDomains(op, iteration_space_analysis.Get(op)))) {
540567
return mlir::failure();
541568
}
542569
return VerifyDependencies(op, iteration_space_analysis,
@@ -725,8 +752,8 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
725752
return mlir::success();
726753
}
727754

728-
LoopNest LoopFusionAnalysis::GetLoopNest(
729-
llvm::ArrayRef<mlir::Attribute> loops) const {
755+
LoopNest LoopFusionAnalysis::GetLoopNest(ComputeOp op) const {
756+
llvm::ArrayRef<mlir::Attribute> loops = op.LoopNestLoops();
730757
llvm::SmallVector<mlir::StringAttr> loop_names;
731758
loop_names.reserve(loops.size());
732759
for (mlir::Attribute attr : loops) {
@@ -762,7 +789,7 @@ LoopNest LoopFusionAnalysis::GetLoopNest(
762789
}
763790
result.domain_to_loops =
764791
MappingAttr::get(context_, result.domain.size(), iters);
765-
result.loops_shape = DomainShapeAttr::get(context_, shape_dims);
792+
result.shape = DomainShapeAttr::get(context_, shape_dims);
766793

767794
return result;
768795
}

0 commit comments

Comments
 (0)