@@ -52,7 +52,8 @@ static llvm::SmallVector<ValueViewOp, 4> SortValueViews(SairProgramOp program) {
5252 return sorted;
5353}
5454
55- Buffer::Buffer (mlir::Type element_type, int rank, ComputeOp op, int result)
55+ Buffer::Buffer (mlir::Type element_type, int rank, ComputeOp op, int result,
56+ const LoopFusionAnalysis &fusion_analysis)
5657 : loc_(op.getLoc()), element_type_(element_type) {
5758 assert (element_type != nullptr );
5859 assert (rank >= 0 );
@@ -67,28 +68,28 @@ Buffer::Buffer(mlir::Type element_type, int rank, ComputeOp op, int result)
6768 auto none_expr = MappingNoneExpr::get (element_type.getContext ());
6869 layout_.resize (rank, none_expr);
6970 writes_.emplace_back (op, result);
71+
72+ loop_nest_mapping_ = fusion_analysis.GetLoopNest (loop_nest_).domain_to_loops ;
7073}
7174
72- void Buffer::TrimLoopNest (int new_size,
73- const LoopFusionAnalysis &fusion_analysis) {
75+ void Buffer::TrimLoopNest (int new_size) {
7476 assert (new_size <= loop_nest_.size ());
7577 loop_nest_.resize (new_size);
78+ loop_nest_mapping_ = loop_nest_mapping_.Resize (new_size);
7679 if (domain_.empty ()) return ;
7780
7881 mlir::MLIRContext *context = element_type_.getContext ();
7982 // Compute dimensions used by layout.
80- llvm::SmallBitVector used_dimensions (domain_. size () );
83+ llvm::SmallBitVector used_dimensions = loop_nest_mapping_. DependencyMask ( );
8184 for (MappingExpr layout_expr : layout_) {
8285 layout_expr.SetDependenciesInMask (used_dimensions);
8386 }
8487
8588 // Trim domain from unused dimensions.
8689 llvm::SmallVector<ValueAccess> old_domain;
8790 std::swap (old_domain, domain_);
88- llvm::append_range (domain_, fusion_analysis.GetLoopNest (loop_nest_).domain );
8991 llvm::SmallVector<MappingExpr> renaming (old_domain.size (),
9092 MappingNoneExpr::get (context));
91-
9293 for (int dim : used_dimensions.set_bits ()) {
9394 // Already added to the new domain.
9495 if (renaming[dim].isa <MappingDimExpr>()) continue ;
@@ -114,6 +115,19 @@ void Buffer::AddWrite(ComputeOp op, int result) {
114115 writes_.emplace_back (op, result);
115116}
116117
118+ void Buffer::AddRead (ComputeOp op, int operand) {
119+ reads_.emplace_back (op, operand);
120+ }
121+
122+ MappingAttr Buffer::PrefixedLayout () const {
123+ mlir::MLIRContext *context = loop_nest_mapping_.getContext ();
124+ llvm::SmallVector<MappingExpr> exprs;
125+ exprs.reserve (loop_nest_.size () + layout_.size ());
126+ llvm::append_range (exprs, loop_nest_mapping_);
127+ llvm::append_range (exprs, layout_);
128+ return MappingAttr::get (context, domain_.size (), exprs);
129+ }
130+
117131StorageAnalysis::StorageAnalysis (mlir::Operation *operation)
118132 : StorageAnalysis(operation->getContext ()) {
119133 mlir::LogicalResult result = Init (cast<SairProgramOp>(operation));
@@ -234,7 +248,8 @@ static mlir::LogicalResult DeclareBuffer(
234248 op->getResult (result).getType ().cast <ValueType>().ElementType ();
235249
236250 int rank = attr.layout ().mapping ().size ();
237- auto it = buffer_map.try_emplace (attr.name (), element_type, rank, op, result);
251+ auto it = buffer_map.try_emplace (attr.name (), element_type, rank, op, result,
252+ fusion_analysis);
238253 Buffer &buffer = it.first ->second ;
239254
240255 if (!it.second ) {
@@ -281,7 +296,7 @@ static mlir::LogicalResult DeclareBuffer(
281296 }
282297 }
283298
284- buffer.TrimLoopNest (num_deps, fusion_analysis );
299+ buffer.TrimLoopNest (num_deps);
285300 return mlir::success ();
286301}
287302
@@ -408,20 +423,13 @@ static mlir::LogicalResult CheckMallocInsertionPoint(
408423// `loop_nest`. Increases `min_num_loops` to the minimal number of loops needed
409424// in `loop_nest` for the layout to be valid.
410425static mlir::LogicalResult CheckLayoutMapping (
411- const LoopNest &loop_nest, mlir::StringAttr buffer_name,
412- const Buffer &buffer, const IterationSpaceAnalysis &iteration_spaces,
413- int &min_num_loops) {
414- mlir::MLIRContext *context = loop_nest.domain_to_loops .getContext ();
426+ mlir::StringAttr buffer_name, const Buffer &buffer,
427+ const IterationSpaceAnalysis &iteration_spaces, int &min_num_loops) {
428+ mlir::MLIRContext *context = buffer_name.getContext ();
415429 int domain_size = buffer.domain ().size ();
416430 int loop_nest_size = buffer.loop_nest ().size ();
417431
418- // Get a mapping that maps both loop-nest and layout indices. This corresponds
419- // to the different instances of the buffer.
420- llvm::SmallVector<MappingExpr> exprs;
421- exprs.reserve (loop_nest_size + buffer.layout ().size ());
422- llvm::append_range (exprs, loop_nest.domain_to_loops );
423- llvm::append_range (exprs, buffer.layout ());
424- auto mapping = MappingAttr::get (context, domain_size, exprs);
432+ MappingAttr mapping = buffer.PrefixedLayout ();
425433
426434 // Update `min_num_loops` based on domain dimensions layout depends on.
427435 llvm::SmallBitVector used_dimensions (domain_size);
@@ -499,13 +507,13 @@ mlir::LogicalResult StorageAnalysis::Init(SairProgramOp program) {
499507 // Check that layout mapping is correct and compute the minimal loop nest
500508 // each buffer needs to be nested in.
501509 int min_num_loops = 0 ;
502- if (mlir::failed (CheckLayoutMapping (loop_nest, name, buffer,
503- iteration_spaces, min_num_loops))) {
510+ if (mlir::failed (CheckLayoutMapping (name, buffer, iteration_spaces ,
511+ min_num_loops))) {
504512 return mlir::failure ();
505513 }
506514
507515 // Minimize layout loop-nest.
508- buffer.TrimLoopNest (min_num_loops, fusion_analysis );
516+ buffer.TrimLoopNest (min_num_loops);
509517 }
510518
511519 // Compute value storages.
@@ -529,6 +537,16 @@ mlir::LogicalResult StorageAnalysis::Init(SairProgramOp program) {
529537 }
530538 }
531539
540+ // Register buffer reads.
541+ program.walk ([&](ComputeOp op) {
542+ for (ValueOperand operand : cast<SairOp>(*op).ValueOperands ()) {
543+ const ValueStorage &storage = value_storages_[operand.value ()];
544+ if (storage.buffer_name == nullptr ) continue ;
545+ buffers_.find (storage.buffer_name )
546+ ->second .AddRead (op, operand.position ());
547+ }
548+ });
549+
532550 return mlir::success ();
533551}
534552
0 commit comments