1919
2020namespace sair {
2121
22+ IterationSpace::IterationSpace (llvm::SmallVector<mlir::StringAttr> loop_names,
23+ MappingAttr domain_to_loops)
24+ : loop_names_(std::move(loop_names)), domain_to_loops_(domain_to_loops) {
25+ assert (loop_names_.size () == domain_to_loops_.size ());
26+ }
27+
2228// Infers the iteration space for the current operation from iteration space of
2329// the given operand. Trims inner loops so than only loops iterating on
2430// dimensions mapped by the mapping remain. The resulting loop nest may
2531// not cover all dimensions of the current operation.
26- static mlir::ArrayAttr InferIterationSpace (
27- mlir::ArrayAttr operand_iteration_space, ValueOperand &operand) {
28- mlir::MLIRContext *context = operand_iteration_space.getContext ();
32+ static IterationSpace InferIterationSpace (
33+ const IterationSpace &operand_iteration_space, ValueOperand &operand) {
2934 MappingAttr mapping = operand.Mapping ();
3035
31- llvm::SmallVector<mlir::Attribute> iteration_space;
32- for (mlir::Attribute attr : operand_iteration_space.getValue ()) {
33- LoopAttr loop = attr.cast <LoopAttr>();
34- if (loop.iter ().MinDomainSize () > mapping.size ()) break ;
35-
36- MappingExpr new_iter = loop.iter ().SubstituteDims (mapping.Dimensions ());
37- LoopAttr new_loop = LoopAttr::get (loop.name (), new_iter, context);
38- iteration_space.push_back (new_loop);
36+ llvm::SmallVector<mlir::StringAttr> loop_names;
37+ for (auto [name, iter] :
38+ llvm::zip (operand_iteration_space.loop_names (),
39+ operand_iteration_space.domain_to_loops ())) {
40+ if (iter.MinDomainSize () > mapping.size ()) break ;
41+ loop_names.push_back (name);
3942 }
43+
44+ MappingAttr domain_to_loops = mapping.Compose (
45+ operand_iteration_space.domain_to_loops ().Resize (loop_names.size ()));
46+
4047 // If the iteration space is infered from loop-carried dimensions, trim inner
4148 // parallel dimensions as inner parallel dimension open at the end of the
4249 // previous iteration along loop-carried dimension may not be open at the
4350 // beginning of the current iteration.
4451 if (operand.AllowUseBeforeDef ()) {
4552 llvm::SmallBitVector carrying_dims = operand.CarryingDims ();
46- while (!iteration_space.empty ()) {
47- LoopAttr loop = iteration_space.back ().cast <LoopAttr>();
48- int domain_size = mapping.UseDomainSize ();
49- if (loop.iter ().DependencyMask (domain_size).anyCommon (carrying_dims)) {
53+ int domain_size = mapping.UseDomainSize ();
54+ int new_size = loop_names.size ();
55+ for (; new_size > 0 ; --new_size) {
56+ MappingExpr expr = domain_to_loops.Dimension (new_size - 1 );
57+ if (expr.DependencyMask (domain_size).anyCommon (carrying_dims)) {
5058 break ;
5159 }
52- iteration_space.pop_back ();
5360 }
61+ loop_names.resize (new_size);
62+ domain_to_loops = domain_to_loops.Resize (new_size);
5463 }
55- return mlir::ArrayAttr::get (context, iteration_space);
64+
65+ return IterationSpace (std::move (loop_names), domain_to_loops);
5666}
5767
5868IterationSpaceAnalysis::IterationSpaceAnalysis (SairProgramOp program_op) {
@@ -62,39 +72,60 @@ IterationSpaceAnalysis::IterationSpaceAnalysis(SairProgramOp program_op) {
6272 }
6373}
6474
65- llvm::ArrayRef<mlir::Attribute> IterationSpaceAnalysis::IterationSpace (
66- SairOp op) const {
67- return iteration_space_.find (op.getOperation ())->second .getValue ();
75+ const IterationSpace &IterationSpaceAnalysis::Get (SairOp op) const {
76+ return iteration_space_.find (op.getOperation ())->second ;
6877}
6978
70- llvm::ArrayRef<mlir::Attribute> IterationSpaceAnalysis::IterationSpace (
71- mlir::Value value) const {
72- return IterationSpace (value.getDefiningOp ());
79+ const IterationSpace &IterationSpaceAnalysis::Get (mlir::Value value) const {
80+ return Get (value.getDefiningOp ());
7381}
7482
75- mlir::ArrayAttr IterationSpaceAnalysis::ComputeIterationSpace (
83+ const IterationSpace & IterationSpaceAnalysis::ComputeIterationSpace (
7684 mlir::Operation *operation) {
7785 if (auto it = iteration_space_.find (operation);
7886 it != iteration_space_.end ()) {
7987 return it->second ;
8088 }
8189
8290 mlir::MLIRContext *context = operation->getContext ();
83- mlir::ArrayAttr iteration_space = mlir::ArrayAttr::get (context, {});
91+ SairOp sair_op = cast<SairOp>(operation);
92+ int domain_size = sair_op.domain ().size ();
93+
94+ // Handle ComputeOp case.
8495 if (auto compute_op = dyn_cast<ComputeOp>(operation)) {
85- iteration_space = compute_op.loop_nest ().getValueOr (iteration_space);
86- } else if (auto infer_iteration_space =
87- dyn_cast<InferIterationSpaceOp>(operation)) {
88- // Temporarily set the loop nest to nullptr to avoid infinite recursion.
89- iteration_space_[operation] = iteration_space;
90- int operand_pos = infer_iteration_space.infer_iteration_space_operand ();
91- ValueOperand operand = cast<SairOp>(operation).ValueOperands ()[operand_pos];
92- mlir::Operation *defining_op = operand.value ().getDefiningOp ();
93- mlir::ArrayAttr parent_iteration_space = ComputeIterationSpace (defining_op);
94- iteration_space = InferIterationSpace (parent_iteration_space, operand);
95- }
96- iteration_space_[operation] = iteration_space;
97- return iteration_space;
96+ int num_loops = compute_op.LoopNestLoops ().size ();
97+ llvm::SmallVector<MappingExpr> exprs;
98+ exprs.reserve (num_loops);
99+ llvm::SmallVector<mlir::StringAttr> loop_names;
100+ loop_names.reserve (num_loops);
101+
102+ for (mlir::Attribute attr : compute_op.LoopNestLoops ()) {
103+ LoopAttr loop = attr.cast <LoopAttr>();
104+ loop_names.push_back (loop.name ());
105+ exprs.push_back (loop.iter ());
106+ }
107+ auto mapping = MappingAttr::get (context, domain_size, exprs);
108+ return iteration_space_.try_emplace (operation, loop_names, mapping)
109+ .first ->second ;
110+ }
111+
112+ // Temporarily set an empty iteration space to avoid infinite recursion.
113+ auto empty_mapping = MappingAttr::get (context, domain_size, {});
114+ llvm::SmallVector<mlir::StringAttr> empty_names;
115+ auto it =
116+ iteration_space_.try_emplace (operation, empty_names, empty_mapping).first ;
117+
118+ auto infer_iteration_space = dyn_cast<InferIterationSpaceOp>(operation);
119+ if (infer_iteration_space == nullptr ) return it->second ;
120+
121+ int operand_pos = infer_iteration_space.infer_iteration_space_operand ();
122+ ValueOperand operand = sair_op.ValueOperands ()[operand_pos];
123+ mlir::Operation *defining_op = operand.value ().getDefiningOp ();
124+ const IterationSpace &parent_iteration_space =
125+ ComputeIterationSpace (defining_op);
126+ it = iteration_space_.find (operation);
127+ it->second = InferIterationSpace (parent_iteration_space, operand);
128+ return it->second ;
98129}
99130
100131// Analysis that keeps track of dependencies between loops.
@@ -169,17 +200,15 @@ class LoopNestConstraintsAnalysis {
169200 }
170201 }
171202
172- mlir::ArrayRef<mlir::Attribute> iteration_space =
173- iteration_spaces.IterationSpace (operation);
203+ const IterationSpace &iteration_space = iteration_spaces.Get (operation);
174204 llvm::SmallBitVector closed_dims = op.ResultsDimDependencies ();
175205 bool closed_dims_seen = false ;
176- for (mlir::Attribute attr : iteration_space) {
177- LoopAttr loop = attr.cast <LoopAttr>();
178- constraints.open_loops .insert (loop.name ());
179- llvm::SmallBitVector iter_dims =
180- loop.iter ().DependencyMask (op.domain ().size ());
206+ for (int i = 0 , e = iteration_space.size (); i < e; ++i) {
207+ constraints.open_loops .insert (iteration_space.loop_names ()[i]);
208+ MappingExpr expr = iteration_space.domain_to_loops ().Dimension (i);
209+ llvm::SmallBitVector iter_dims = expr.DependencyMask (op.domain ().size ());
181210 if (iter_dims.anyCommon (closed_dims)) {
182- constraints.closed_loops .insert (loop. name () );
211+ constraints.closed_loops .insert (iteration_space. loop_names ()[i] );
183212 closed_dims_seen = true ;
184213 }
185214 if (closed_dims_seen) {
@@ -342,45 +371,46 @@ static mlir::LogicalResult VerifyLoopsOpen(
342371// * `carrying_dims`: if `dependency` is a loop-carried operand, lists
343372// dimensions carrying the value of `dependency` across iterations.
344373static mlir::LogicalResult VerifyDependency (
345- SairOp op, llvm::ArrayRef<mlir::Attribute> op_loop_nest,
346- ValueAccess dependency, const llvm::SmallBitVector &dim_dependencies,
374+ SairOp op, const IterationSpace & op_loop_nest, ValueAccess dependency ,
375+ const llvm::SmallBitVector &dim_dependencies,
347376 const llvm::SmallBitVector &carrying_dims,
348377 const IterationSpaceAnalysis &iteration_space_analysis,
349378 const LoopNestConstraintsAnalysis &loop_constraints_analysis) {
350- mlir::ArrayRef<mlir::Attribute> dep_loop_nest =
351- iteration_space_analysis.IterationSpace (dependency.value );
379+ const IterationSpace & dep_loop_nest =
380+ iteration_space_analysis.Get (dependency.value );
352381
353382 // Verify dependencies with the operand loop nest.
354- for (auto [op_attr, dep_attr] : llvm::zip (op_loop_nest, dep_loop_nest)) {
355- LoopAttr op_loop = op_attr.cast <LoopAttr>();
356- LoopAttr dep_loop = dep_attr.cast <LoopAttr>();
357- if (op_loop.name () != dep_loop.name ()) break ;
383+ int min_size = std::min (op_loop_nest.size (), dep_loop_nest.size ());
384+ for (int i = 0 ; i < min_size; ++i) {
385+ if (op_loop_nest.loop_names ()[i] != dep_loop_nest.loop_names ()[i]) break ;
358386 // Ensure that we can unify the iterator of both loops if they are fused.
359387 MappingExpr expected_expr =
360- dep_loop.iter ().SubstituteDims (dependency.mapping .Dimensions ());
361- if (expected_expr.Unify (op_loop.iter ()) != nullptr ) continue ;
362- return (op.emitError () << " loop " << op_loop.name ()
388+ dep_loop_nest.domain_to_loops ().Dimension (i).SubstituteDims (
389+ dependency.mapping .Dimensions ());
390+ MappingExpr given_expr = op_loop_nest.domain_to_loops ().Dimension (i);
391+ if (expected_expr.Unify (given_expr) != nullptr ) continue ;
392+ return (op.emitError () << " loop " << op_loop_nest.loop_names ()[i]
363393 << " violates a data dependency" )
364394 .attachNote (dependency.value .getLoc ())
365395 << " dependency from this operation" ;
366396 }
367397
368398 const LoopNestConstraintsAnalysis::Constraints &constraints =
369399 loop_constraints_analysis.GetConstraints (dependency.value );
370- for (mlir::Attribute attr : op_loop_nest) {
371- LoopAttr loop = attr. cast <LoopAttr>() ;
372- if (constraints.closed_loops .contains (loop. name () )) {
373- return op.emitError () << " loop " << loop. name ()
374- << " must be closed before this operation" ;
400+ for (int i = 0 , e = op_loop_nest. size (); i < e; ++i ) {
401+ mlir::StringAttr name = op_loop_nest. loop_names ()[i] ;
402+ if (constraints.closed_loops .contains (name)) {
403+ return op.emitError ()
404+ << " loop " << name << " must be closed before this operation" ;
375405 }
376406
377- if (!constraints.open_loops .contains (loop. name () )) continue ;
378- llvm::SmallBitVector iter_dims =
379- loop. iter () .DependencyMask (op.domain ().size ());
407+ if (!constraints.open_loops .contains (name)) continue ;
408+ MappingExpr expr = op_loop_nest. domain_to_loops (). Dimension (i);
409+ llvm::SmallBitVector iter_dims = expr .DependencyMask (op.domain ().size ());
380410 if (!dim_dependencies.anyCommon (iter_dims)) continue ;
381411
382412 return (dependency.value .getDefiningOp ()->emitError ()
383- << " operation cannot be nested in loop " << loop. name () )
413+ << " operation cannot be nested in loop " << name)
384414 .attachNote (op.getLoc ())
385415 << " because of this operation" ;
386416 }
@@ -406,8 +436,7 @@ static mlir::LogicalResult VerifyDependency(
406436static mlir::LogicalResult VerifyDependencies (
407437 SairOp op, IterationSpaceAnalysis &iteration_space_analysis,
408438 LoopNestConstraintsAnalysis &loop_constaints_analysis) {
409- llvm::ArrayRef<mlir::Attribute> loop_nest =
410- iteration_space_analysis.IterationSpace (op);
439+ const IterationSpace &loop_nest = iteration_space_analysis.Get (op);
411440
412441 int domain_size = op.domain ().size ();
413442 for (int i = 0 ; i < domain_size; ++i) {
@@ -459,14 +488,13 @@ static mlir::LogicalResult VerifyLoopRanges(
459488
460489// Ensure that each loop only iterate along a single sub-domain.
461490static mlir::LogicalResult VerifySubDomains (
462- SairOp op, llvm::ArrayRef<mlir::Attribute> iteration_space) {
491+ SairOp op, const IterationSpace & iteration_space) {
463492 llvm::SmallVector<int > sub_domains = op.SubDomains ();
464493 assert (!sub_domains.empty () || iteration_space.empty ());
465494
466- for (mlir::Attribute attr : iteration_space) {
467- LoopAttr loop = attr.cast <LoopAttr>();
468- llvm::SmallBitVector dimensions =
469- loop.iter ().DependencyMask (op.domain ().size ());
495+ for (int i = 0 , e = iteration_space.size (); i < e; ++i) {
496+ MappingExpr expr = iteration_space.domain_to_loops ().Dimension (i);
497+ llvm::SmallBitVector dimensions = expr.DependencyMask (op.domain ().size ());
470498 if (!dimensions.any ()) continue ;
471499
472500 // Compute the sub-domain the loop belongs to. If the iterator is not fully
@@ -476,7 +504,7 @@ static mlir::LogicalResult VerifySubDomains(
476504 int sub_domain = 0 ;
477505 int min_dim_index = 0 ;
478506 int max_dim_index = sub_domains[0 ];
479- if (loop. iter () .IsFullySpecified ()) {
507+ if (expr .IsFullySpecified ()) {
480508 int first = dimensions.find_first ();
481509 while (first >= max_dim_index) {
482510 min_dim_index = max_dim_index;
@@ -488,8 +516,8 @@ static mlir::LogicalResult VerifySubDomains(
488516 // sub-domain.
489517 if (dimensions.find_first () < min_dim_index ||
490518 dimensions.find_last () >= max_dim_index) {
491- return op.emitError ()
492- << " loop " << loop. name () << " crosses sub-domains boundaries" ;
519+ return op.emitError () << " loop " << iteration_space. loop_names ()[i]
520+ << " crosses sub-domains boundaries" ;
493521 }
494522 }
495523 return mlir::success ();
@@ -535,8 +563,7 @@ mlir::LogicalResult VerifyLoopNests(SairProgramOp program) {
535563
536564 // Verify dependencies.
537565 result = program.walk ([&](SairOp op) -> mlir::WalkResult {
538- if (mlir::failed (VerifySubDomains (
539- op, iteration_space_analysis.IterationSpace (op)))) {
566+ if (mlir::failed (VerifySubDomains (op, iteration_space_analysis.Get (op)))) {
540567 return mlir::failure ();
541568 }
542569 return VerifyDependencies (op, iteration_space_analysis,
@@ -725,8 +752,8 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
725752 return mlir::success ();
726753}
727754
728- LoopNest LoopFusionAnalysis::GetLoopNest (
729- llvm::ArrayRef<mlir::Attribute> loops) const {
755+ LoopNest LoopFusionAnalysis::GetLoopNest (ComputeOp op) const {
756+ llvm::ArrayRef<mlir::Attribute> loops = op. LoopNestLoops ();
730757 llvm::SmallVector<mlir::StringAttr> loop_names;
731758 loop_names.reserve (loops.size ());
732759 for (mlir::Attribute attr : loops) {
@@ -762,7 +789,7 @@ LoopNest LoopFusionAnalysis::GetLoopNest(
762789 }
763790 result.domain_to_loops =
764791 MappingAttr::get (context_, result.domain .size (), iters);
765- result.loops_shape = DomainShapeAttr::get (context_, shape_dims);
792+ result.shape = DomainShapeAttr::get (context_, shape_dims);
766793
767794 return result;
768795}
0 commit comments