Skip to content

Commit 9f7d367

Browse files
ulysseBcopybara-github
authored andcommitted
Replace memory_space attribute by storage attribute.
The `storage` attribute applies to `ComputeOp` operations only. It specifies an optional `BufferAttr` attribute for each output. The buffer attribute specifies: * A memory space, currently equal to "memory" or "register" * A name, for "memory" buffers only. All buffers sharing the same name are stored in the same memref. * a mapping from a subset of loops to dimensions of the memref backing the buffer. * Remove `memory_space` attribute from all operations. * Add a `storage` attribute to ComputeOp operations. * Introduce a `ValueViewOp` interface to infer the storage of non-compute operations. * Extend `LoopFusionAnalysis` to return the domain and the mapping from domain to loop iteration of loop nest. * Define a `StorageAnalysis` that checks and unify buffer attributes with the same names. As a side effect of change `memory_space` into `storage`: * Update passes using `memory_space` to use `storage` instead. * Change DefaultMemorySpace pass into a DefaultStorage pass. Only support the subset of operations that was previously supported by DefaultMemorySpace. * Update sair-from-linalg pass to insert copy operations after sair.from_memerf operations. This clears the storage of values originating from sair.from_memref so that is does not conflicts with other storage constraints. * Remove unused ValueProducer interface. * Update ResolveUnificationConstraints in `util.cc` so that it can also handle unification of buffers layouts. PiperOrigin-RevId: 361096395
1 parent ca598dc commit 9f7d367

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1812
-545
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ add_mlir_library(sair_dialect
8383
sair_ops.cc
8484
sair_types.cc
8585
util.cc
86+
storage.cc
8687

8788
DEPENDS
8889
sair_ops_inc_gen

canonicalization_patterns.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ bool SimplifyProjOp(ValueOperand &use, ProjOp op,
7777
rewriter.setInsertionPoint(op);
7878
ProjOp new_op = rewriter.create<ProjOp>(
7979
op.getLoc(), op.getType(), op.parallel_domain(), projection_domain,
80-
mapping_array, prev_op.value(), shape, op.memory_spaceAttr());
80+
mapping_array, prev_op.value(), shape);
8181
use.set_value(new_op.result());
8282

8383
return true;
@@ -141,7 +141,7 @@ class DeduplicateMapInputsOutputs : public OpRewritePattern<SairMapOp> {
141141
llvm::SmallVector<mlir::Value, 4> old_results_to_keep;
142142
llvm::SmallVector<mlir::Value, 4> new_scalar_results;
143143
llvm::SmallVector<mlir::Type, 4> new_result_types;
144-
llvm::SmallVector<mlir::Attribute, 4> new_memory_spaces;
144+
llvm::SmallVector<mlir::Attribute, 4> new_storages;
145145

146146
for (ValueOperand operand : op.ValueOperands()) {
147147
mlir::Value argument =
@@ -176,7 +176,7 @@ class DeduplicateMapInputsOutputs : public OpRewritePattern<SairMapOp> {
176176
// Deduplicate results.
177177
for (int j = 0; j < i; ++j) {
178178
if (scalar_value != return_op.getOperand(j)) continue;
179-
if (op.GetMemorySpace(i) != op.GetMemorySpace(j)) continue;
179+
if (op.Storage(i) != op.Storage(j)) continue;
180180
result.replaceAllUsesWith(op.getResult(j));
181181
break;
182182
}
@@ -187,13 +187,9 @@ class DeduplicateMapInputsOutputs : public OpRewritePattern<SairMapOp> {
187187
old_results_to_keep.push_back(result);
188188
new_scalar_results.push_back(scalar_value);
189189
new_result_types.push_back(result.getType());
190-
mlir::Attribute memory_space =
191-
op.GetMemorySpace(i)
192-
.map([&](int value) -> mlir::Attribute {
193-
return rewriter.getI32IntegerAttr(value);
194-
})
195-
.getValueOr(rewriter.getUnitAttr());
196-
new_memory_spaces.push_back(memory_space);
190+
mlir::Attribute new_storage = op.Storage(i);
191+
new_storages.push_back(new_storage == nullptr ? rewriter.getUnitAttr()
192+
: new_storage);
197193
}
198194

199195
// Create the new operation if necessary.
@@ -210,7 +206,7 @@ class DeduplicateMapInputsOutputs : public OpRewritePattern<SairMapOp> {
210206
SairMapOp new_op = rewriter.create<SairMapOp>(
211207
op.getLoc(), new_result_types, op.domain(),
212208
rewriter.getArrayAttr(new_mappings), new_operands, op.shape(),
213-
op.loop_nestAttr(), rewriter.getArrayAttr(new_memory_spaces));
209+
op.loop_nestAttr(), rewriter.getArrayAttr(new_storages));
214210
new_op.body().takeBody(op.body());
215211

216212
for (auto p : llvm::zip(old_results_to_keep, new_op.results())) {
@@ -310,7 +306,7 @@ class RemoveUnreferencedDims : public OpRewritePattern<OpTy> {
310306
op.getType().template cast<ValueType>().AccessedType(partial_mapping),
311307
parallel_dimensions, projection_dimensions,
312308
rewriter.getArrayAttr(mapping.Inverse().Compose(op.Value().Mapping())),
313-
op.value(), op.shape().AccessedShape(mapping), op.memory_spaceAttr());
309+
op.value(), op.shape().AccessedShape(mapping));
314310
new_op->setDialectAttrs(op->getDialectAttrs());
315311

316312
// Replace the original op.
@@ -356,7 +352,7 @@ class RemoveUnreferencedDims<SairFbyOp> : public OpRewritePattern<SairFbyOp> {
356352
parallel_dimensions, sequential_dimensions,
357353
rewriter.getArrayAttr({inverted_mapping.Compose(op.Init().Mapping()),
358354
inverted_mapping.Compose(op.Value().Mapping())}),
359-
op.init(), op.value(), op.memory_spaceAttr());
355+
op.init(), op.value());
360356
new_op->setDialectAttrs(op->getDialectAttrs());
361357

362358
// Replace the original op.

loop_nest.cc

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -496,30 +496,8 @@ static mlir::LogicalResult VerifySubDomains(
496496
}
497497

498498
mlir::LogicalResult VerifyLoopNests(SairProgramOp program) {
499-
// Verify operands of Sair operands are defined in the same program. This
500-
// check is performed here rather that in SairOp as it is needed for other
501-
// verifications.
502-
mlir::WalkResult result = program.walk([&](SairOp op) -> mlir::WalkResult {
503-
for (mlir::Value dimension : op.domain()) {
504-
mlir::Operation *defining_op = dimension.getDefiningOp();
505-
if (defining_op == nullptr || defining_op->getParentOp() != program) {
506-
return op.emitError()
507-
<< "sair dimensions must be defined in the region they are used";
508-
}
509-
}
510-
for (ValueOperand operand : op.ValueOperands()) {
511-
mlir::Operation *defining_op = operand.value().getDefiningOp();
512-
if (defining_op == nullptr || defining_op->getParentOp() != program) {
513-
return op.emitError()
514-
<< "sair values must be defined in the region they are used";
515-
}
516-
}
517-
return mlir::success();
518-
});
519-
if (result.wasInterrupted()) return mlir::failure();
520-
521499
// Verify loop nests are correct with regard to their operation.
522-
result = program.walk([](ComputeOp op) -> mlir::WalkResult {
500+
mlir::WalkResult result = program.walk([](ComputeOp op) -> mlir::WalkResult {
523501
if (!op.loop_nest().hasValue()) return mlir::WalkResult::advance();
524502
return VerifyLoopNestWellFormed(
525503
cast<SairOp>(op.getOperation()), op.LoopNestLoops());
@@ -569,7 +547,8 @@ mlir::LogicalResult VerifyLoopNests(SairProgramOp program) {
569547
return mlir::success();
570548
}
571549

572-
LoopFusionAnalysis::LoopFusionAnalysis(mlir::Operation *operation) {
550+
LoopFusionAnalysis::LoopFusionAnalysis(mlir::Operation *operation)
551+
: context_(operation->getContext()) {
573552
SairProgramOp program_op = dyn_cast<SairProgramOp>(operation);
574553
if (program_op == nullptr) return;
575554
mlir::LogicalResult status = Init(program_op);
@@ -579,19 +558,17 @@ LoopFusionAnalysis::LoopFusionAnalysis(mlir::Operation *operation) {
579558

580559
std::optional<LoopFusionAnalysis> LoopFusionAnalysis::Create(
581560
SairProgramOp program_op) {
582-
LoopFusionAnalysis analysis;
561+
LoopFusionAnalysis analysis(program_op->getContext());
583562
if (mlir::failed(analysis.Init(program_op))) return std::nullopt;
584563
return analysis;
585564
}
586565

587566
mlir::LogicalResult LoopFusionAnalysis::Init(SairProgramOp program_op) {
588-
mlir::MLIRContext *context = program_op.getContext();
589-
590567
llvm::SmallVector<ComputeOp> work_list;
591568
program_op.walk([&](ComputeOp op) {
592569
auto sair_op = cast<SairOp>(op.getOperation());
593570
int domain_size = sair_op.domain().size();
594-
auto none_expr = MappingNoneExpr::get(context);
571+
auto none_expr = MappingNoneExpr::get(context_);
595572
op_domain_mappings_[op.getOperation()].resize(domain_size, none_expr);
596573
if (!op.loop_nest().hasValue()) return;
597574
work_list.push_back(op);
@@ -638,9 +615,9 @@ mlir::LogicalResult LoopFusionAnalysis::Init(SairProgramOp program_op) {
638615

639616
int domain_size = fusion_class.domain.size();
640617
MappingAttr inverse_loop_nest =
641-
MappingAttr::get(context, domain_size, loop_nest).Inverse();
618+
MappingAttr::get(context_, domain_size, loop_nest).Inverse();
642619

643-
auto hr_domain = DomainShapeAttr::HyperRectangular(context, domain_size);
620+
auto hr_domain = DomainShapeAttr::HyperRectangular(context_, domain_size);
644621
DomainShapeDim loop_shape = fusion_class.iter_expr.AccessedShape(
645622
hr_domain.Dimensions(), inverse_loop_nest);
646623
if (!loop_shape.dependency_mapping().IsFullySpecified()) {
@@ -720,9 +697,23 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
720697
llvm::SmallBitVector constrained_dims =
721698
loop.iter().DependencyMask(domain_size);
722699
for (int dimension : constrained_dims.set_bits()) {
700+
// Compute the mapping to access the dimension in from loop indices.
701+
MappingAttr old_access_mapping = sair_op.shape()
702+
.Dimension(dimension)
703+
.dependency_mapping()
704+
.ResizeUseDomain(domain_size);
705+
MappingAttr new_access_mapping =
706+
loops_to_op_domain_mapping.Compose(old_access_mapping);
707+
if (!new_access_mapping.IsFullySpecified()) {
708+
return op->emitError()
709+
<< "dimension d" << dimension << " in " << loop_name.str()
710+
<< " is used before its dependencies";
711+
}
712+
713+
ValueAccess access = {sair_op.domain()[dimension], new_access_mapping};
723714
if (mlir::failed(ResolveUnificationConstraint(
724-
op, dimension, loop_name.str(), loops_to_op_domain_mapping,
725-
constraints[dimension], fusion_class.domain))) {
715+
op.getLoc(), loop_name.str(), access, constraints[dimension],
716+
fusion_class.domain))) {
726717
return mlir::failure();
727718
}
728719
}
@@ -734,4 +725,46 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
734725
return mlir::success();
735726
}
736727

728+
LoopNest LoopFusionAnalysis::GetLoopNest(
729+
llvm::ArrayRef<mlir::Attribute> loops) const {
730+
llvm::SmallVector<mlir::StringAttr> loop_names;
731+
loop_names.reserve(loops.size());
732+
for (mlir::Attribute attr : loops) {
733+
auto loop = attr.cast<LoopAttr>();
734+
loop_names.push_back(loop.name());
735+
}
736+
return GetLoopNest(loop_names);
737+
}
738+
739+
LoopNest LoopFusionAnalysis::GetLoopNest(
740+
llvm::ArrayRef<mlir::StringAttr> loop_names) const {
741+
LoopNest result;
742+
if (!loop_names.empty()) {
743+
result.domain = GetClass(loop_names.back()).domain;
744+
}
745+
746+
llvm::SmallVector<MappingExpr> iters;
747+
llvm::SmallVector<DomainShapeDim> shape_dims;
748+
iters.reserve(loop_names.size());
749+
shape_dims.reserve(loop_names.size());
750+
for (mlir::StringAttr name : loop_names) {
751+
const LoopFusionClass &fusion_class = GetClass(name);
752+
iters.push_back(fusion_class.iter_expr);
753+
754+
// Resulting loop nest dependencies are pointwise dependencies to a prefix
755+
// of the loop nest.
756+
int num_dependencies = fusion_class.dependencies.size();
757+
auto dim_type = RangeType::get(DomainShapeAttr::get(
758+
context_, llvm::makeArrayRef(shape_dims).take_front(num_dependencies)));
759+
auto dim_mapping =
760+
MappingAttr::GetIdentity(context_, num_dependencies, shape_dims.size());
761+
shape_dims.emplace_back(dim_type, dim_mapping);
762+
}
763+
result.domain_to_loops =
764+
MappingAttr::get(context_, result.domain.size(), iters);
765+
result.loops_shape = DomainShapeAttr::get(context_, shape_dims);
766+
767+
return result;
768+
}
769+
737770
} // namespace sair

loop_nest.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
namespace sair {
2525

26-
// Verifies loop nest attributes of operations nested in the
27-
// sair.program operation.
26+
// Verifies loop nest attributes of operations nested in the sair.program
27+
// operation. Assumes that Sair operands are defined in the same program.
2828
mlir::LogicalResult VerifyLoopNests(SairProgramOp program);
2929

3030
// Analysis of how data is distributed on loop nests iterations. It indicates,
@@ -75,14 +75,14 @@ class IterationSpaceAnalysis {
7575
// A class of fused loops.
7676
struct LoopFusionClass {
7777
// Loops this class depends on.
78-
llvm::SmallVector<mlir::StringAttr, 2> dependencies;
78+
llvm::SmallVector<mlir::StringAttr> dependencies;
7979

8080
// Domain in which the loop size is defined. This is a list of dimensions,
8181
// with an access pattern from dependencies indicies to the domain of each
8282
// dimension.
8383
//
8484
// Domains of outer fusion classes must be a prefix of this one.
85-
llvm::SmallVector<ValueAccess, 4> domain;
85+
llvm::SmallVector<ValueAccess> domain;
8686

8787
// Mapping from domain indices to the loop indices.
8888
MappingExpr iter_expr;
@@ -91,6 +91,17 @@ struct LoopFusionClass {
9191
ComputeOp occurence;
9292
};
9393

94+
// A loop nest of fused loops.
95+
// TODO: use in normalize_loops.cc
96+
struct LoopNest {
97+
// Domain used to define loop ranges.
98+
llvm::ArrayRef<ValueAccess> domain;
99+
// Mapping from `domain` to loops.
100+
MappingAttr domain_to_loops;
101+
// Shape of the resulting loop nest.
102+
DomainShapeAttr loops_shape;
103+
};
104+
94105
// Computes loop fusion classes in a sair program.
95106
class LoopFusionAnalysis {
96107
public:
@@ -106,8 +117,12 @@ class LoopFusionAnalysis {
106117
return fusion_classes_.find(name)->second;
107118
}
108119

120+
// Retrives the unified loop nest corresponding to loops.
121+
LoopNest GetLoopNest(llvm::ArrayRef<mlir::Attribute> loops) const;
122+
LoopNest GetLoopNest(llvm::ArrayRef<mlir::StringAttr> loop_names) const;
123+
109124
private:
110-
LoopFusionAnalysis() {}
125+
LoopFusionAnalysis(mlir::MLIRContext *context) : context_(context) {}
111126

112127
// Populates the analysis with the operations appearing in `program_op`.
113128
mlir::LogicalResult Init(SairProgramOp program_op);
@@ -117,6 +132,7 @@ class LoopFusionAnalysis {
117132
mlir::LogicalResult RegisterLoop(ComputeOp op, LoopAttr loop,
118133
llvm::ArrayRef<mlir::Attribute> outer_loops);
119134

135+
mlir::MLIRContext *context_;
120136
llvm::DenseMap<mlir::Attribute, LoopFusionClass> fusion_classes_;
121137
llvm::DenseMap<mlir::Operation *, llvm::SmallVector<MappingExpr, 4>>
122138
op_domain_mappings_;

sair_attributes.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,14 @@ MappingAttr MappingAttr::Canonicalize() const {
914914
return MappingAttr::get(getContext(), UseDomainSize(), exprs);
915915
}
916916

917+
int MappingAttr::MinDomainSize() const {
918+
int min = 0;
919+
for (MappingExpr expr : Dimensions()) {
920+
min = std::max(min, expr.MinDomainSize());
921+
}
922+
return min;
923+
}
924+
917925
//===----------------------------------------------------------------------===//
918926
// NamedMappingAttr
919927
//===----------------------------------------------------------------------===//
@@ -961,6 +969,12 @@ NamedMappingAttr NamedMappingAttr::get(llvm::ArrayRef<mlir::StringAttr> names,
961969
return Base::get(mapping.getContext(), names, mapping);
962970
}
963971

972+
NamedMappingAttr NamedMappingAttr::GetIdentity(
973+
mlir::MLIRContext *context, llvm::ArrayRef<mlir::StringAttr> names) {
974+
auto mapping = MappingAttr::GetIdentity(context, names.size());
975+
return NamedMappingAttr::get(names, mapping);
976+
}
977+
964978
llvm::ArrayRef<mlir::StringAttr> NamedMappingAttr::names() const {
965979
return getImpl()->names();
966980
}

sair_attributes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class MappingAttr
123123
// depends on.
124124
llvm::SmallBitVector DependencyMask() const;
125125

126+
// Minimal use domain size the mapping could have while remaining valid.
127+
int MinDomainSize() const;
128+
126129
// Indicates if the mapping accesses a single element of the def
127130
// domain per element of the use domain, when considering only the first
128131
// `num_dimensions` of the use domain.
@@ -169,6 +172,10 @@ class NamedMappingAttr
169172
static NamedMappingAttr get(llvm::ArrayRef<mlir::StringAttr> names,
170173
MappingAttr mapping);
171174

175+
// Constructs an instance of NamedMappingAttr with an identity mapping.
176+
static NamedMappingAttr GetIdentity(mlir::MLIRContext *context,
177+
llvm::ArrayRef<mlir::StringAttr> names);
178+
172179
llvm::ArrayRef<mlir::StringAttr> names() const;
173180

174181
MappingAttr mapping() const;

0 commit comments

Comments
 (0)