Skip to content

Commit 9460f6c

Browse files
ulysseBcopybara-github
authored andcommitted
Rematerialization via instances attribute.
Support generating 0 or multiple instances of a single operation in the source file by replacing the `decisions` attribute by an `instances` attribute. The `instances` attribute is an array where each entry represents a different instance of the operation in the generated code. For now, only the first instance is actually checked and generated. Lowering passes expect a single instance to be present and will emit an error otherwise. A new pass ensures that all ComputeOp have at least one instance when assigning default attribute. PiperOrigin-RevId: 395042740
1 parent ddcd1d7 commit 9460f6c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+883
-643
lines changed

canonicalization_patterns.cc

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,16 @@ class SimplifySairOperands : public RewritePattern {
132132
// Remove duplicate inputs and duplicate outputs of sair.map operations.
133133
mlir::LogicalResult DeduplicateMapInputsOutputs(
134134
SairMapOp op, mlir::PatternRewriter &rewriter) {
135+
if (op.HasCopies()) return mlir::failure();
136+
135137
int domain_size = op.domain().size();
136138
llvm::SmallVector<mlir::Value> new_operands;
137139
llvm::SmallVector<mlir::Attribute> new_mappings;
138140

139141
llvm::SmallVector<mlir::Value> old_results_to_keep;
140142
llvm::SmallVector<mlir::Value> new_scalar_results;
141143
llvm::SmallVector<mlir::Type> new_result_types;
142-
llvm::SmallVector<mlir::Attribute> new_storages;
143-
llvm::SmallVector<mlir::Attribute> new_copies;
144+
llvm::SmallBitVector remaining_outputs(op->getNumResults());
144145

145146
std::vector<int> block_args_to_erase;
146147
for (ValueOperand operand : op.ValueOperands()) {
@@ -181,7 +182,15 @@ mlir::LogicalResult DeduplicateMapInputsOutputs(
181182
// Deduplicate results.
182183
for (int j = 0; j < i; ++j) {
183184
if (scalar_value != return_op.getOperand(j)) continue;
184-
if (op.Storage(i) != op.Storage(j)) continue;
185+
bool same_storage = true;
186+
for (int k = 0, e = op.NumInstances(); k < e; ++k) {
187+
if (op.GetDecisions(k).storage() != op.GetDecisions(j).storage()) {
188+
same_storage = false;
189+
break;
190+
}
191+
}
192+
if (!same_storage) continue;
193+
185194
// Don't deduplicate with dead results that will be removed.
186195
if (op.getResult(j).use_empty()) continue;
187196
result.replaceAllUsesWith(op.getResult(j));
@@ -196,10 +205,7 @@ mlir::LogicalResult DeduplicateMapInputsOutputs(
196205
old_results_to_keep.push_back(result);
197206
new_scalar_results.push_back(scalar_value);
198207
new_result_types.push_back(result.getType());
199-
mlir::Attribute new_storage = op.Storage(i);
200-
new_storages.push_back(new_storage == nullptr ? rewriter.getUnitAttr()
201-
: new_storage);
202-
new_copies.push_back(rewriter.getArrayAttr(op.GetCopies(i)));
208+
remaining_outputs.set(i);
203209
}
204210

205211
// Create the new operation if necessary.
@@ -213,15 +219,14 @@ mlir::LogicalResult DeduplicateMapInputsOutputs(
213219
rewriter.eraseOp(return_op);
214220

215221
rewriter.setInsertionPoint(op);
216-
DecisionsAttr decisions = op.GetDecisions();
217-
auto new_decisions =
218-
DecisionsAttr::get(decisions.sequence(), decisions.loop_nest(),
219-
rewriter.getArrayAttr(new_storages),
220-
decisions.expansion(), op.getContext());
222+
mlir::ArrayAttr new_instances = MkArrayAttrMapper<DecisionsAttr>(
223+
MapStorage(MkArrayAttrFilter(remaining_outputs)))(op.instancesAttr());
224+
mlir::ArrayAttr new_copies =
225+
MkArrayAttrFilter(remaining_outputs)(op.copiesAttr());
221226
SairMapOp new_op = rewriter.create<SairMapOp>(
222227
op.getLoc(), new_result_types, op.domain(),
223228
rewriter.getArrayAttr(new_mappings), new_operands, op.shape(),
224-
new_decisions, rewriter.getArrayAttr(new_copies));
229+
new_instances, new_copies);
225230
new_op.body().takeBody(op.body());
226231

227232
for (auto [old_res, new_res] :

expansion.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class CopyExpansionPattern : public ExpansionPattern {
9090

9191
mlir::LogicalResult CopyExpansionPattern::Match(ComputeOpInstance op) const {
9292
if (op.is_copy()) return mlir::success();
93-
ComputeOp compute_op = op.AsComputeOp();
93+
ComputeOp compute_op = op.GetComputeOp();
9494
return mlir::success(isa<SairCopyOp>(compute_op.getOperation()));
9595
}
9696

expansion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class TypedExpansionPattern : public ExpansionPattern {
6666

6767
mlir::LogicalResult Match(ComputeOpInstance op) const final {
6868
if (op.is_copy()) return mlir::failure();
69-
auto cast_op = dyn_cast<OpTy>(*op.AsComputeOp());
69+
auto cast_op = dyn_cast<OpTy>(*op.GetComputeOp());
7070
if (cast_op == nullptr) return mlir::failure();
7171
return Match(cast_op);
7272
}

sair_attributes.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,30 @@ mlir::LogicalResult VerifyMappingShape(const AttrLocation &loc,
13371337
return mlir::success();
13381338
}
13391339

1340+
//===----------------------------------------------------------------------===//
1341+
// DecisionsAttr
1342+
//===----------------------------------------------------------------------===//
1343+
1344+
std::function<DecisionsAttr(DecisionsAttr)> MapLoopNest(
1345+
std::function<mlir::ArrayAttr(mlir::ArrayAttr)> loop_nest_fn) {
1346+
return [loop_nest_fn](DecisionsAttr decisions) -> DecisionsAttr {
1347+
if (decisions == nullptr) return nullptr;
1348+
return DecisionsAttr::get(
1349+
decisions.sequence(), loop_nest_fn(decisions.loop_nest()),
1350+
decisions.storage(), decisions.expansion(), decisions.getContext());
1351+
};
1352+
}
1353+
1354+
std::function<DecisionsAttr(DecisionsAttr)> MapStorage(
1355+
std::function<mlir::ArrayAttr(mlir::ArrayAttr)> storage_fn) {
1356+
return [storage_fn](DecisionsAttr decisions) -> DecisionsAttr {
1357+
if (decisions == nullptr) return nullptr;
1358+
return DecisionsAttr::get(decisions.sequence(), decisions.loop_nest(),
1359+
storage_fn(decisions.storage()),
1360+
decisions.expansion(), decisions.getContext());
1361+
};
1362+
}
1363+
13401364
} // namespace sair
13411365

13421366
#include "sair_structs.cc.inc"

sair_attributes.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,4 +556,23 @@ class MappingUnStripeExpr
556556
using namespace mlir; // NOLINT
557557
#include "sair_structs.h.inc"
558558

559+
namespace sair {
560+
561+
// Below are helper functions to manipulate DecisionsAttr. Each helper takes a
562+
// function and returns a function that applies the first function to a field of
563+
// a DecisionsAttr. Helpers return functions rather than directly applying the
564+
// transformation so that it is easier to combine transformations.
565+
566+
// Takes a function that updates a loop nest and returns a function that updates
567+
// the loop nest field of a DecisionsAttr.
568+
std::function<DecisionsAttr(DecisionsAttr)> MapLoopNest(
569+
std::function<mlir::ArrayAttr(mlir::ArrayAttr)> loop_nest_fn);
570+
571+
// Takes a function that updates a list of storages and returns a function that
572+
// updates the storage field of a DecisionsAttr.
573+
std::function<DecisionsAttr(DecisionsAttr)> MapStorage(
574+
std::function<mlir::ArrayAttr(mlir::ArrayAttr)> storage_fn);
575+
576+
} // namespace sair
577+
559578
#endif // SAIR_SAIR_ATTRIBUTES_H_

sair_base.td

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ def SairDecisionsAttr : SairStructAttr<"DecisionsAttr", [
215215
StructFieldAttr<"expansion", OptionalAttr<StrAttr>>
216216
]>;
217217

218-
// An attribute that species copies of the result of an operation.
218+
// An attribute that specifies instances of an operation.
219+
def SairInstancesAttr
220+
: TypedArrayAttrBase<SairDecisionsAttr, "array of Sair decisions">;
221+
222+
// An attribute that specifies copies of the result of an operation.
219223
def SairCopiesAttr
220-
: TypedArrayAttrBase<
221-
TypedArrayAttrBase<SairDecisionsAttr, "array of Sair decisions">,
222-
"array of array of Sair decisions">;
224+
: TypedArrayAttrBase<SairInstancesAttr, "array of array of Sair decisions">;
223225

224226
//===----------------------------------------------------------------------===//
225227
// Sair Types
@@ -352,8 +354,7 @@ def SairOpInterface : OpInterface<"SairOp"> {
352354
dependencies accross iterations for the given !sair.value operand}],
353355
"llvm::SmallBitVector", "CarryingDimensions",
354356
(ins "int":$sair_operand), [{}], [{
355-
int size = cast<ConcreteOp>(this->getOperation()).domain().size();
356-
return llvm::SmallBitVector(size);
357+
return llvm::SmallBitVector($_op.domain().size());
357358
}]
358359
>,
359360
InterfaceMethod<
@@ -370,7 +371,12 @@ def SairOpInterface : OpInterface<"SairOp"> {
370371
[{Indicates the size of each sub-domain. The first sub-domain is always
371372
the parallel domain}],
372373
"llvm::SmallVector<int, 2>", "SubDomains", (ins), [{}]
373-
>
374+
>,
375+
InterfaceMethod<
376+
[{Indicates if the operation has exactly one instance and no copy.}],
377+
"bool", "HasExactlyOneInstance", (ins), [{}], [{
378+
return sair::HasExactlyOneInstance(this->getOperation());
379+
}]>
374380
];
375381

376382
let verify = [{return VerifySairOp(op);}];
@@ -410,34 +416,59 @@ def SairComputeOp : OpInterface<"ComputeOp"> {
410416

411417
let methods = [
412418
InterfaceMethod<
413-
"Returns lowering decisions for the operation",
414-
"DecisionsAttr", "GetDecisions", (ins), [{}], [{
415-
if (!$_op.decisions().hasValue()) {
416-
return DecisionsAttr::get(
417-
nullptr, nullptr, nullptr, nullptr, $_op.getContext());
418-
}
419-
return *$_op.decisions();
419+
"Returns the number of instances of the operation",
420+
"int", "NumInstances", (ins), [{}], [{
421+
llvm::Optional<mlir::ArrayAttr> instances = $_op.instances();
422+
if (!instances.hasValue()) return 0;
423+
return instances.getValue().size();
424+
}]>,
425+
InterfaceMethod<
426+
"Returns lowering decisions for the given operation instance",
427+
"DecisionsAttr", "GetDecisions", (ins "int":$instance), [{}], [{
428+
mlir::ArrayAttr instances = $_op.instances().getValue();
429+
return instances.getValue()[instance].cast<DecisionsAttr>();
420430
}]
421431
>,
422432
InterfaceMethod<
423433
"Sets lowering decisions for the operation",
424-
"void", "SetDecisions", (ins "DecisionsAttr":$value), [{}], [{
425-
$_op.decisionsAttr(value);
434+
"void", "SetDecisions", (ins "int":$instance, "DecisionsAttr":$value),
435+
[{}], [{
436+
auto instances = llvm::to_vector<4>(
437+
$_op.instances().getValue().getValue());
438+
instances[instance] = value;
439+
$_op.instancesAttr(mlir::ArrayAttr::get($_op.getContext(), instances));
426440
}]
427441
>,
442+
InterfaceMethod<
443+
"Appends an instance to the list of instances",
444+
"void", "AddInstance", (ins "DecisionsAttr":$value),
445+
[{}], [{
446+
llvm::SmallVector<mlir::Attribute> instances;
447+
if ($_op.instances().hasValue()) {
448+
instances = llvm::to_vector<4>(
449+
$_op.instances().getValue().getValue());
450+
}
451+
instances.push_back(value);
452+
$_op.instancesAttr(mlir::ArrayAttr::get($_op.getContext(), instances));
453+
}]
454+
>,
455+
456+
// TODO(ulysse): Legacy interface that returns decisions for the first
457+
// instance. This will be removed.
428458
InterfaceMethod<
429459
"Returns the loop nest to generate when lowering the operation",
430460
"llvm::Optional<mlir::ArrayAttr>", "loop_nest", (ins), [{}], [{
431-
if (!$_op.decisions().hasValue()) return llvm::None;
432-
if ($_op.decisions()->loop_nest() == nullptr) return llvm::None;
433-
return $_op.decisions()->loop_nest();
461+
if ($_op.NumInstances() < 1) return llvm::None;
462+
DecisionsAttr decisions = $_op.GetDecisions(0);
463+
if (decisions.loop_nest() == nullptr) return llvm::None;
464+
return decisions.loop_nest();
434465
}]
435466
>,
436467
InterfaceMethod<
437468
"Sets the loop nest to generate when lowering the operation",
438469
"void", "setLoopNest", (ins "mlir::ArrayAttr":$loop_nest), [{
439-
DecisionsAttr decisions = $_op.GetDecisions();
440-
$_op.SetDecisions(DecisionsAttr::get(
470+
DecisionsAttr decisions = $_op.GetDecisions(0);
471+
$_op.SetDecisions(0, DecisionsAttr::get(
441472
decisions.sequence(), loop_nest, decisions.storage(),
442473
decisions.expansion(), $_op.getContext()));
443474
}]
@@ -455,9 +486,10 @@ def SairComputeOp : OpInterface<"ComputeOp"> {
455486
InterfaceMethod<
456487
"Returns the storage of the values produced by the operation",
457488
"Optional<mlir::ArrayAttr>", "storage", (ins), [{}], [{
458-
if (!$_op.decisions().hasValue()) return llvm::None;
459-
if ($_op.decisions()->storage() == nullptr) return llvm::None;
460-
return $_op.decisions()->storage();
489+
if ($_op.NumInstances() < 1) return llvm::None;
490+
DecisionsAttr decisions = $_op.GetDecisions(0);
491+
if (decisions.storage() == nullptr) return llvm::None;
492+
return decisions.storage();
461493
}]
462494
>,
463495
InterfaceMethod<
@@ -472,7 +504,7 @@ def SairComputeOp : OpInterface<"ComputeOp"> {
472504
InterfaceMethod<
473505
"Sets the storage of the values produced by the operation",
474506
"void", "SetStorage", (ins "int":$result, "BufferAttr":$buffer), [{}], [{
475-
DecisionsAttr decisions = $_op.GetDecisions();
507+
DecisionsAttr decisions = $_op.GetDecisions(0);
476508
llvm::SmallVector<mlir::Attribute> values;
477509
int num_results = $_op->getNumResults();
478510
if (decisions.storage() == nullptr) {
@@ -482,25 +514,26 @@ def SairComputeOp : OpInterface<"ComputeOp"> {
482514
}
483515
values[result] = buffer;
484516
auto new_storage = mlir::ArrayAttr::get($_op.getContext(), values);
485-
$_op.SetDecisions(DecisionsAttr::get(
517+
$_op.SetDecisions(0, DecisionsAttr::get(
486518
decisions.sequence(), decisions.loop_nest(), new_storage,
487519
decisions.expansion(), $_op.getContext()));
488520
}]
489521
>,
490522
InterfaceMethod<
491523
"Returns the sequence number of this operation as signed integer.",
492524
"::llvm::Optional<int64_t>", "Sequence", (ins), [{}], [{
493-
if (!$_op.decisions().hasValue()) return llvm::None;
494-
if ($_op.decisions()->sequence() == nullptr) return llvm::None;
495-
return $_op.decisions()->sequence().getInt();
525+
if ($_op.NumInstances() < 1) return llvm::None;
526+
DecisionsAttr decisions = $_op.GetDecisions(0);
527+
if (decisions.sequence() == nullptr) return ::llvm::None;
528+
return decisions.sequence().getInt();
496529
}]
497530
>,
498531
InterfaceMethod<
499532
"Sets the sequence number of the operation to the given value.",
500533
"void", "SetSequence", (ins "int64_t":$seq), [{}], [{
501534
::mlir::Builder builder($_op->getContext());
502-
DecisionsAttr decisions = GetDecisions();
503-
$_op.SetDecisions(DecisionsAttr::get(
535+
DecisionsAttr decisions = GetDecisions(0);
536+
$_op.SetDecisions(0, DecisionsAttr::get(
504537
builder.getI64IntegerAttr(seq), decisions.loop_nest(),
505538
decisions.storage(), decisions.expansion(),
506539
$_op.getContext()));
@@ -511,7 +544,7 @@ def SairComputeOp : OpInterface<"ComputeOp"> {
511544
let verify = [{return sair::VerifyComputeOp(op);}];
512545

513546
let extraClassDeclaration = [{
514-
static constexpr llvm::StringRef kDecisionsAttrName = "decisions";
547+
static constexpr llvm::StringRef kInstancesAttrName = "instances";
515548
}];
516549
}
517550

0 commit comments

Comments
 (0)