Skip to content

Commit 1d628d8

Browse files
ulysseB authored and copybara-github committed
Materialize buffers pass.
Add a pass that introduces alloc and free operations for buffers. PiperOrigin-RevId: 361105497
1 parent 38aa6f6 commit 1d628d8

File tree

12 files changed

+444
-41
lines changed

12 files changed

+444
-41
lines changed

loop_nest.cc

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -789,9 +789,26 @@ LoopNest LoopFusionAnalysis::GetLoopNest(
789789
}
790790
result.domain_to_loops =
791791
MappingAttr::get(context_, result.domain.size(), iters);
792-
result.shape = DomainShapeAttr::get(context_, shape_dims);
792+
result.normalized_shape = DomainShapeAttr::get(context_, shape_dims);
793793

794794
return result;
795795
}
796796

797+
// Builds the shape of `domain`: one `DomainShapeDim` is emitted per entry of
// `domain`, pairing the range type of the accessed value with a dependency
// mapping derived from `domain_to_loops` and the access mapping of the entry.
// The attribute is created in the MLIR context of `normalized_shape`.
DomainShapeAttr LoopNest::DomainShape() const {
798+
llvm::SmallVector<DomainShapeDim> shape_dims;
799+
shape_dims.reserve(domain.size());
800+
MappingAttr loops_to_domain = domain_to_loops.Inverse();
801+
for (const ValueAccess &access : domain) {
802+
// `shape_dims.size()` is the number of dimensions emitted so far, so the
// mapping of each dimension is resized to the dimensions that precede it
// before being composed with the access mapping.
MappingAttr mapping = loops_to_domain.Resize(shape_dims.size())
803+
.Inverse()
804+
.Compose(access.mapping);
805+
// The cast expects every domain value to have a range type.
shape_dims.emplace_back(access.value.getType().cast<RangeType>(), mapping);
806+
}
807+
return DomainShapeAttr::get(normalized_shape.getContext(), shape_dims);
808+
}
809+
810+
// Returns the shape of the loop nest, obtained by accessing the domain shape
// (see `DomainShape()`) through `domain_to_loops`. The domain shape is
// recomputed on every call.
DomainShapeAttr LoopNest::Shape() const {
811+
return DomainShape().AccessedShape(domain_to_loops);
812+
}
813+
797814
} // namespace sair

loop_nest.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,19 @@ struct LoopFusionClass {
111111
struct LoopNest {
112112
// Domain used to define loop ranges.
113113
llvm::ArrayRef<ValueAccess> domain;
114+
114115
// Mapping from `domain` to loops.
115116
MappingAttr domain_to_loops;
116-
// Shape of the resulting loop nest.
117-
DomainShapeAttr shape;
117+
118+
// Shape of the nest, normalized so that dependencies between dimensions are
119+
// identity mappings.
120+
DomainShapeAttr normalized_shape;
121+
122+
// Shape of the domain the loop nest is defined from.
123+
DomainShapeAttr DomainShape() const;
124+
125+
// Shape of the loop nest.
126+
DomainShapeAttr Shape() const;
118127
};
119128

120129
// Computes loop fusion classes in a sair program.

storage.cc

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ static llvm::SmallVector<ValueViewOp, 4> SortValueViews(SairProgramOp program) {
5252
return sorted;
5353
}
5454

55-
Buffer::Buffer(mlir::Type element_type, int rank, ComputeOp op, int result)
55+
Buffer::Buffer(mlir::Type element_type, int rank, ComputeOp op, int result,
56+
const LoopFusionAnalysis &fusion_analysis)
5657
: loc_(op.getLoc()), element_type_(element_type) {
5758
assert(element_type != nullptr);
5859
assert(rank >= 0);
@@ -67,28 +68,28 @@ Buffer::Buffer(mlir::Type element_type, int rank, ComputeOp op, int result)
6768
auto none_expr = MappingNoneExpr::get(element_type.getContext());
6869
layout_.resize(rank, none_expr);
6970
writes_.emplace_back(op, result);
71+
72+
loop_nest_mapping_ = fusion_analysis.GetLoopNest(loop_nest_).domain_to_loops;
7073
}
7174

72-
void Buffer::TrimLoopNest(int new_size,
73-
const LoopFusionAnalysis &fusion_analysis) {
75+
void Buffer::TrimLoopNest(int new_size) {
7476
assert(new_size <= loop_nest_.size());
7577
loop_nest_.resize(new_size);
78+
loop_nest_mapping_ = loop_nest_mapping_.Resize(new_size);
7679
if (domain_.empty()) return;
7780

7881
mlir::MLIRContext *context = element_type_.getContext();
7982
// Compute dimensions used by layout.
80-
llvm::SmallBitVector used_dimensions(domain_.size());
83+
llvm::SmallBitVector used_dimensions = loop_nest_mapping_.DependencyMask();
8184
for (MappingExpr layout_expr : layout_) {
8285
layout_expr.SetDependenciesInMask(used_dimensions);
8386
}
8487

8588
// Trim domain from unused dimensions.
8689
llvm::SmallVector<ValueAccess> old_domain;
8790
std::swap(old_domain, domain_);
88-
llvm::append_range(domain_, fusion_analysis.GetLoopNest(loop_nest_).domain);
8991
llvm::SmallVector<MappingExpr> renaming(old_domain.size(),
9092
MappingNoneExpr::get(context));
91-
9293
for (int dim : used_dimensions.set_bits()) {
9394
// Already added to the new domain.
9495
if (renaming[dim].isa<MappingDimExpr>()) continue;
@@ -114,6 +115,19 @@ void Buffer::AddWrite(ComputeOp op, int result) {
114115
writes_.emplace_back(op, result);
115116
}
116117

118+
// Records that `op` reads from this buffer by appending the pair
// (`op`, `operand`) to `reads_`. `operand` is the position of the value
// operand of `op` that is stored in the buffer. No deduplication is performed.
void Buffer::AddRead(ComputeOp op, int operand) {
119+
reads_.emplace_back(op, operand);
120+
}
121+
122+
// Returns a mapping made of the expressions of `loop_nest_mapping_` followed
// by the expressions of `layout_`, built with `domain_.size()` as its size
// argument (presumably the number of domain dimensions the mapping applies
// to -- confirm against `MappingAttr::get`).
MappingAttr Buffer::PrefixedLayout() const {
123+
mlir::MLIRContext *context = loop_nest_mapping_.getContext();
124+
llvm::SmallVector<MappingExpr> exprs;
125+
// Reserve room for the loop-nest prefix plus the layout expressions.
exprs.reserve(loop_nest_.size() + layout_.size());
126+
llvm::append_range(exprs, loop_nest_mapping_);
127+
llvm::append_range(exprs, layout_);
128+
return MappingAttr::get(context, domain_.size(), exprs);
129+
}
130+
117131
StorageAnalysis::StorageAnalysis(mlir::Operation *operation)
118132
: StorageAnalysis(operation->getContext()) {
119133
mlir::LogicalResult result = Init(cast<SairProgramOp>(operation));
@@ -234,7 +248,8 @@ static mlir::LogicalResult DeclareBuffer(
234248
op->getResult(result).getType().cast<ValueType>().ElementType();
235249

236250
int rank = attr.layout().mapping().size();
237-
auto it = buffer_map.try_emplace(attr.name(), element_type, rank, op, result);
251+
auto it = buffer_map.try_emplace(attr.name(), element_type, rank, op, result,
252+
fusion_analysis);
238253
Buffer &buffer = it.first->second;
239254

240255
if (!it.second) {
@@ -281,7 +296,7 @@ static mlir::LogicalResult DeclareBuffer(
281296
}
282297
}
283298

284-
buffer.TrimLoopNest(num_deps, fusion_analysis);
299+
buffer.TrimLoopNest(num_deps);
285300
return mlir::success();
286301
}
287302

@@ -408,20 +423,13 @@ static mlir::LogicalResult CheckMallocInsertionPoint(
408423
// `loop_nest`. Increases `min_num_loops` to the minimal number of loops needed
409424
// in `loop_nest` for the layout to be valid.
410425
static mlir::LogicalResult CheckLayoutMapping(
411-
const LoopNest &loop_nest, mlir::StringAttr buffer_name,
412-
const Buffer &buffer, const IterationSpaceAnalysis &iteration_spaces,
413-
int &min_num_loops) {
414-
mlir::MLIRContext *context = loop_nest.domain_to_loops.getContext();
426+
mlir::StringAttr buffer_name, const Buffer &buffer,
427+
const IterationSpaceAnalysis &iteration_spaces, int &min_num_loops) {
428+
mlir::MLIRContext *context = buffer_name.getContext();
415429
int domain_size = buffer.domain().size();
416430
int loop_nest_size = buffer.loop_nest().size();
417431

418-
// Get a mapping that maps both loop-nest and layout indices. This corresponds
419-
// to the different instances of the buffer.
420-
llvm::SmallVector<MappingExpr> exprs;
421-
exprs.reserve(loop_nest_size + buffer.layout().size());
422-
llvm::append_range(exprs, loop_nest.domain_to_loops);
423-
llvm::append_range(exprs, buffer.layout());
424-
auto mapping = MappingAttr::get(context, domain_size, exprs);
432+
MappingAttr mapping = buffer.PrefixedLayout();
425433

426434
// Update `min_num_loops` based on domain dimensions layout depends on.
427435
llvm::SmallBitVector used_dimensions(domain_size);
@@ -499,13 +507,13 @@ mlir::LogicalResult StorageAnalysis::Init(SairProgramOp program) {
499507
// Check that layout mapping is correct and compute the minimal loop nest
500508
// each buffer needs to be nested in.
501509
int min_num_loops = 0;
502-
if (mlir::failed(CheckLayoutMapping(loop_nest, name, buffer,
503-
iteration_spaces, min_num_loops))) {
510+
if (mlir::failed(CheckLayoutMapping(name, buffer, iteration_spaces,
511+
min_num_loops))) {
504512
return mlir::failure();
505513
}
506514

507515
// Minimize layout loop-nest.
508-
buffer.TrimLoopNest(min_num_loops, fusion_analysis);
516+
buffer.TrimLoopNest(min_num_loops);
509517
}
510518

511519
// Compute value storages.
@@ -529,6 +537,16 @@ mlir::LogicalResult StorageAnalysis::Init(SairProgramOp program) {
529537
}
530538
}
531539

540+
// Register buffer reads.
541+
program.walk([&](ComputeOp op) {
542+
for (ValueOperand operand : cast<SairOp>(*op).ValueOperands()) {
543+
const ValueStorage &storage = value_storages_[operand.value()];
544+
if (storage.buffer_name == nullptr) continue;
545+
buffers_.find(storage.buffer_name)
546+
->second.AddRead(op, operand.position());
547+
}
548+
});
549+
532550
return mlir::success();
533551
}
534552

storage.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class Buffer {
3636
// Create a new buffer written to by the given operation. The operation must
3737
// have a loop_nest attribute set. `result` is the position of `op` result
3838
// stored in `buffer`.
39-
Buffer(mlir::Type element_type, int rank, ComputeOp op, int result);
39+
Buffer(mlir::Type element_type, int rank, ComputeOp op, int result,
40+
const LoopFusionAnalysis &fusion_analysis);
4041

4142
// Number of dimensions in the buffer layout.
4243
int rank() const { return layout_.size(); }
@@ -58,25 +59,40 @@ class Buffer {
5859
// result stored in the buffer. Never empty.
5960
llvm::ArrayRef<std::pair<ComputeOp, int>> writes() const { return writes_; }
6061

62+
// List of operations that read from the buffer, with the position of the Sair
63+
// value operand.
64+
llvm::ArrayRef<std::pair<ComputeOp, int>> reads() const { return reads_; }
65+
6166
// Get the location of the first operation defining the buffer.
6267
mlir::Location getLoc() const { return loc_; }
6368

69+
// Mapping of domain to layout prefixed by loop nest iterators. The prefix
70+
// corresponds to the different instances of the buffer.
71+
MappingAttr PrefixedLayout() const;
72+
6473
// Registers an operation writing to the buffer.
6574
void AddWrite(ComputeOp op, int result);
6675

76+
// Registers an operation reading the buffer.
77+
void AddRead(ComputeOp op, int operand);
78+
6779
// Trims the loop-nest to the given size.
68-
void TrimLoopNest(int new_size, const LoopFusionAnalysis &fusion_analysis);
80+
void TrimLoopNest(int new_size);
6981

7082
// Unifies a dimension of the layout with another expression.
7183
void UnifyLayoutDim(int layout_dim, MappingExpr expr);
7284

7385
private:
7486
mlir::Location loc_;
7587
mlir::Type element_type_;
88+
7689
llvm::SmallVector<mlir::StringAttr> loop_nest_;
90+
MappingAttr loop_nest_mapping_;
91+
7792
llvm::SmallVector<ValueAccess> domain_;
7893
llvm::SmallVector<MappingExpr> layout_;
7994
llvm::SmallVector<std::pair<ComputeOp, int>> writes_;
95+
llvm::SmallVector<std::pair<ComputeOp, int>> reads_;
8096
};
8197

8298
// Computes buffers metadata and storage information for each value.
@@ -95,6 +111,11 @@ class StorageAnalysis {
95111
return buffers_.find(buffer)->second;
96112
}
97113

114+
// List of buffers indexed by name.
115+
// Read-only view of the whole buffer map. The map is returned by const
// reference, not copied.
const llvm::DenseMap<mlir::Attribute, Buffer> &buffers() const {
116+
return buffers_;
117+
}
118+
98119
// Retrieves the storage of a value.
99120
const ValueStorage &GetStorage(mlir::Value value) const {
100121
return value_storages_.find(value)->second;

test/materialize_buffers.mlir

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// RUN: sair-opt %s -sair-materialize-buffers -mlir-print-local-scope | FileCheck %s
2+
3+
// CHECK-LABEL: @static_shape
4+
func @static_shape(%arg0: f32) {
5+
sair.program {
6+
%0 = sair.from_scalar %arg0 : !sair.value<(), f32>
7+
%1 = sair.static_range 16 step 2: !sair.range
8+
// CHECK: %[[V0:.*]] = sair.alloc {loop_nest = [],
9+
// CHECK-SAME: storage = [{layout = #sair.named_mapping<[] -> ()>, space = "register"}]
10+
// CHECK-SAME: : !sair.value<(), memref<8xf32>>
11+
// CHECK: sair.copy[d0:%{{.*}}]
12+
%2 = sair.copy[d0:%1] %0 {
13+
loop_nest = [{name = "A", iter = #sair.mapping_expr<d0>}],
14+
storage = [{
15+
name = "B", space = "memory",
16+
layout = #sair.named_mapping<[d0:"A"] -> (d0)>
17+
}]
18+
} : !sair.value<d0:range, f32>
19+
// CHECK: sair.copy[d0:%{{.*}}]
20+
%3 = sair.copy[d0:%1] %2(d0) {
21+
loop_nest = [{name = "B", iter = #sair.mapping_expr<d0>}],
22+
storage = [{space = "register", layout = #sair.named_mapping<[] -> ()>}]
23+
} : !sair.value<d0:range, f32>
24+
// CHECK: sair.free %[[V0]] {loop_nest = []} : !sair.value<(), memref<8xf32>>
25+
sair.exit
26+
}
27+
return
28+
}
29+
30+
// CHECK-LABEL: @dynamic_shape
31+
func @dynamic_shape(%arg0: f32, %arg1: index, %arg2: index) {
32+
sair.program {
33+
%0 = sair.from_scalar %arg0 : !sair.value<(), f32>
34+
// CHECK: %[[V1:.*]] = sair.from_scalar %{{.*}} : !sair.value<(), index>
35+
%1 = sair.from_scalar %arg1 : !sair.value<(), index>
36+
// CHECK: %[[V2:.*]] = sair.from_scalar %{{.*}} : !sair.value<(), index>
37+
%2 = sair.from_scalar %arg2 : !sair.value<(), index>
38+
%3 = sair.dyn_range %1, %2 step 4 : !sair.range
39+
// CHECK: %[[V3:.*]] = sair.map %[[V1]], %[[V2]] attributes {
40+
// CHECK: loop_nest = []
41+
// CHECK: storage = [{layout = #sair.named_mapping<[] -> ()>, space = "register"}]
42+
// CHECK: } {
43+
// CHECK: ^{{.*}}(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index):
44+
// CHECK: %[[V4:.*]] = affine.apply
45+
// CHECK: affine_map<(d0, d1) -> ((d1 - d0) ceildiv 4)>(%[[ARG0]], %[[ARG1]])
46+
// CHECK: sair.return %[[V4]]
47+
// CHECK: } : #sair.shape<()>, (index, index) -> index
48+
49+
// CHECK: %[[V5:.*]] = sair.alloc %[[V3]]
50+
// CHECK: : !sair.value<(), memref<?xf32>>
51+
%4 = sair.copy[d0:%3] %0 {
52+
loop_nest = [{name = "A", iter = #sair.mapping_expr<d0>}],
53+
storage = [{
54+
name = "B", space = "memory",
55+
layout = #sair.named_mapping<[d0:"A"] -> (d0)>
56+
}]
57+
} : !sair.value<d0:range, f32>
58+
sair.exit
59+
}
60+
return
61+
}
62+
63+
// CHECK-LABEL: @loop_nest
64+
func @loop_nest(%arg0: f32) {
65+
sair.program {
66+
%0 = sair.from_scalar %arg0 : !sair.value<(), f32>
67+
%1 = sair.static_range 16 : !sair.range
68+
// CHECK: %[[D0:.*]] = sair.placeholder : !sair.range
69+
70+
// CHECK: %[[V0:.*]] = sair.map[d0:%[[D0]]] attributes {
71+
// CHECK: loop_nest = [{iter = #sair.mapping_expr<d0>, name = "A"}]
72+
// CHECK: } {
73+
// CHECK: ^{{.*}}(%[[ARG0:.*]]: index):
74+
// CHECK: %[[V1:.*]] = affine.apply affine_map<(d0) -> (d0)>(%[[ARG0]])
75+
// CHECK: %[[C4:.*]] = constant 4
76+
// CHECK: %[[V2:.*]] = addi %[[V1]], %[[C4]]
77+
// CHECK: %[[C16:.*]] = constant 16
78+
// CHECK: %[[V3:.*]] = cmpi ult, %[[C16]], %[[V2]]
79+
// CHECK: %[[V4:.*]] = select %[[V3]], %[[C16]], %[[V2]]
80+
// CHECK: %[[V5:.*]] = affine.apply affine_map<(d0, d1) -> (d1 - d0)>
81+
// CHECK: sair.return %[[V5]] : index
82+
// CHECK: } : #sair.shape<d0:range>, () -> index
83+
84+
// CHECK: %[[V6:.*]] = sair.alloc[d0:%[[D0]]] %[[V0]](d0) {
85+
// CHECK: loop_nest = [{iter = #sair.mapping_expr<d0>, name = "A"}]
86+
// CHECK: } : !sair.value<d0:range, memref<?xf32>>
87+
// CHECK: sair.copy
88+
%2 = sair.copy[d0:%1] %0 {
89+
loop_nest = [
90+
{name = "A", iter = #sair.mapping_expr<stripe(d0, 4)>},
91+
{name = "B", iter = #sair.mapping_expr<stripe(d0, 1 size 4)>}
92+
],
93+
storage = [{
94+
name = "buf", space = "memory",
95+
layout = #sair.named_mapping<[d0:"B"] -> (d0)>
96+
}]
97+
} : !sair.value<d0:range, f32>
98+
// CHECK: sair.copy
99+
%3 = sair.copy[d0:%1] %2(d0) {
100+
loop_nest = [
101+
{name = "A", iter = #sair.mapping_expr<stripe(d0, 4)>},
102+
{name = "C", iter = #sair.mapping_expr<stripe(d0, 1 size 4)>}
103+
],
104+
storage = [{
105+
name = "buf", space = "memory",
106+
layout = #sair.named_mapping<[d0:"C"] -> (d0)>
107+
}]
108+
} : !sair.value<d0:range, f32>
109+
// CHECK: sair.free[d0:%[[D0]]] %[[V6]](d0) {
110+
// CHECK: loop_nest = [{iter = #sair.mapping_expr<d0>, name = "A"}]
111+
// CHECK: } : !sair.value<d0:range, memref<?xf32>>
112+
sair.exit
113+
}
114+
return
115+
}

transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ add_mlir_library(sair_lowering
6565
introduce_loops.cc
6666
introduce_memrefs.cc
6767
lower_to_map.cc
68+
materialize_buffers.cc
6869
normalize_loops.cc
6970

7071
DEPENDS

transforms/lowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> CreateInlineTrivialOpsPass();
5555
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
5656
CreateMaterializeMemRefsPass();
5757

58+
// Replaces Sair values by buffers as specified by the `storage` attribute.
59+
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
60+
CreateMaterializeBuffersPass();
61+
5862
// Replaces iteration dimensions by loops in sair.map and sair.map_reduce
5963
// operations.
6064
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> CreateIntroduceLoopsPass();

0 commit comments

Comments
 (0)