Skip to content

Commit 89a623b

Browse files
ulysseBcopybara-github
authored andcommitted
Placeholder operations for defining dummy domains.
Introduce a new PlaceholderOp operation that defines a range of unknown size. This operation is used when introducing new operations during code generation. In many cases, the domain of new operations will not be used, either: * because the operation is a ComputeOp with a loop-nest and loops sizes are already defined by other operations. * because the operation is not a compute operation so its domain won't be used to introduce loops or buffers. PiperOrigin-RevId: 361100458
1 parent 9f7d367 commit 89a623b

File tree

10 files changed

+148
-5
lines changed

10 files changed

+148
-5
lines changed

canonicalization_patterns.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,11 @@ void SairDynRangeOp::getCanonicalizationPatterns(
417417
patterns.insert<SimplifySairOperands>();
418418
}
419419

420+
void SairPlaceholderOp::getCanonicalizationPatterns(
421+
mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
422+
patterns.insert<SimplifySairOperands>();
423+
}
424+
420425
void SairStaticRangeOp::getCanonicalizationPatterns(
421426
mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
422427
patterns.insert<SimplifySairOperands>();

sair_ops.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,22 @@ ParseResult ParseStaticRangeOp(mlir::OpAsmParser &parser,
201201
parser.addTypeToList(type, result.types));
202202
}
203203

204+
// Parses the placeholder dimension. This operation has an iteration domain and
205+
// returns a range value. The syntax is the following.
206+
//
207+
// placeholder-op ::= `sair.placeholder` domain : range-type
208+
//
209+
ParseResult ParsePlaceholderOp(mlir::OpAsmParser &parser,
210+
mlir::OperationState &result) {
211+
llvm::SmallVector<mlir::OpAsmParser::OperandType> domain;
212+
RangeType type;
213+
214+
return mlir::failure(ParseDomain(parser, domain) ||
215+
parser.parseColonType<RangeType>(type) ||
216+
parser.addTypeToList(type, result.types) ||
217+
ResolveDomain(parser, type.Shape(), domain, result));
218+
}
219+
204220
// Parses the copy operation. This operation has an iteration domain and
205221
// accesses a single Sair value. The syntax for the operation is the following.
206222
//
@@ -616,6 +632,12 @@ void Print(SairStaticRangeOp op, OpAsmPrinter &printer) {
616632
printer << " : " << op.getType();
617633
}
618634

635+
static void Print(SairPlaceholderOp op, mlir::OpAsmPrinter &printer) {
636+
printer << SairPlaceholderOp::getOperationName();
637+
PrintDomain(op.domain(), printer);
638+
printer << " : " << op.range().getType();
639+
}
640+
619641
// Prints the copy operation.
620642
void Print(SairCopyOp op, OpAsmPrinter &printer) {
621643
printer << SairCopyOp::getOperationName();
@@ -892,7 +914,7 @@ llvm::SmallBitVector SairStoreToMemRefOp::DimsDependingOnOperand(
892914

893915
ParseResult ParseDomain(
894916
mlir::OpAsmParser &parser,
895-
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> &dimensions) {
917+
llvm::SmallVectorImpl<mlir::OpAsmParser::OperandType> &dimensions) {
896918
if (failed(parser.parseOptionalLSquare())) return success();
897919
do {
898920
std::string dim_name = "d" + std::to_string(dimensions.size());
@@ -1559,6 +1581,10 @@ llvm::SmallVector<int, 2> SairDynRangeOp::SubDomains() {
15591581
return {static_cast<int>(domain().size())};
15601582
}
15611583

1584+
llvm::SmallVector<int, 2> SairPlaceholderOp::SubDomains() {
1585+
return {static_cast<int>(domain().size())};
1586+
}
1587+
15621588
llvm::SmallVector<int, 2> SairCopyOp::SubDomains() {
15631589
return {static_cast<int>(domain().size())};
15641590
}
@@ -1672,6 +1698,14 @@ SairOp SairStaticRangeOp::ReCreateWithNewDomain(
16721698
"not called by NormalizeLoops because the op defines a dimension");
16731699
}
16741700

1701+
SairOp SairPlaceholderOp::ReCreateWithNewDomain(
1702+
llvm::ArrayRef<llvm::SmallVector<mlir::Value, 4>> new_domains,
1703+
DomainShapeAttr new_shape, MappingAttr new_to_old_mapping,
1704+
mlir::OpBuilder &builder) {
1705+
llvm_unreachable(
1706+
"not called by NormalizeLoops because the op defines a dimension");
1707+
}
1708+
16751709
SairOp SairCopyOp::ReCreateWithNewDomain(
16761710
llvm::ArrayRef<llvm::SmallVector<mlir::Value, 4>> new_domains,
16771711
DomainShapeAttr new_shape, MappingAttr new_to_old_mapping,

sair_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ namespace sair {
5353
// '['.
5454
ParseResult ParseDomain(
5555
mlir::OpAsmParser &parser,
56-
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> &dimensions);
56+
llvm::SmallVectorImpl<mlir::OpAsmParser::OperandType> &dimensions);
5757
// Resolves the operands that consitute the dimensions of an iteration domain
5858
// and registers them in 'result'.
5959
ParseResult ResolveDomain(mlir::OpAsmParser &parser,

sair_ops.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,30 @@ def SairStaticRangeOp : SairOp<"static_range", [NoSideEffect, SairRangeOp]> {
106106
}];
107107
}
108108

109+
def SairPlaceholderOp : SairOp<"placeholder", [NoSideEffect]> {
110+
let summary = "Placeholder for an iteration dimension";
111+
112+
let description = [{
113+
Defines an iteration dimension that will be replaced by an actual iteration
114+
dimension during loop-normalization pass. This is used to introduce
115+
operations in specific loop nests before dimensions defining loop ranges are
116+
introduced.
117+
118+
If a dimension bound to a placeholder operation appears in loop iterators
119+
mapping, the dimension must be defined by another operation with a
120+
non-placeholder operation. In that sense, they behave similarly to `none` in
121+
loop-nest description.
122+
}];
123+
124+
let arguments = (ins Variadic<SairRange>:$domain);
125+
let results = (outs SairRange:$range);
126+
127+
let parser = [{return ParsePlaceholderOp(parser, result);}];
128+
let printer = [{return Print(*this, p);}];
129+
130+
DerivedAttr shape = SairResultDomainShapeAttr;
131+
}
132+
109133
def SairCopyOp : SairOp<"copy", [
110134
NoSideEffect,
111135
SairComputeOp,

test/introduce_loops_invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,27 @@ func @size_not_in_register(%arg0: index) {
135135
}
136136
return
137137
}
138+
139+
// -----
140+
141+
func @placeholder() {
142+
sair.program {
143+
%0 = sair.static_range 8 : !sair.range
144+
// expected-error @+1 {{placeholders must be replaced by actual dimensions before introducing loops}}
145+
%1 = sair.placeholder : !sair.range
146+
sair.map[d0:%1] attributes {
147+
loop_nest = [{name = "loopA", iter = #sair.mapping_expr<d0>}]
148+
} {
149+
^bb0(%arg1: index):
150+
sair.return
151+
} : #sair.shape<d0:range>, () -> ()
152+
sair.map[d0:%0] attributes {
153+
loop_nest = [{name = "loopA", iter = #sair.mapping_expr<d0>}]
154+
} {
155+
^bb0(%arg1: index):
156+
sair.return
157+
} : #sair.shape<d0:range>, () -> ()
158+
sair.exit
159+
}
160+
return
161+
}

test/invalid.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,3 +1519,18 @@ func @map_reduce_init_result_storages_must_match(%arg0: f32) {
15191519
}
15201520
return
15211521
}
1522+
1523+
// -----
1524+
1525+
func @placeholder_loop_nest_unspecified(%arg0: f32) {
1526+
sair.program {
1527+
%0 = sair.from_scalar %arg0 : !sair.value<(), f32>
1528+
%1 = sair.placeholder : !sair.range
1529+
// expected-error @+1 {{loop "loopA" iterator is not fully specified}}
1530+
%2 = sair.copy[d0:%1] %0 {
1531+
loop_nest = [{name = "loopA", iter = #sair.mapping_expr<d0>}]
1532+
} : !sair.value<d0:range, f32>
1533+
sair.exit
1534+
}
1535+
return
1536+
}

test/roundtrip.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,32 @@ func @storage_stripe(%arg0: f32) {
627627
}
628628
return
629629
}
630+
631+
// CHECK-LABEL: @placeholder
632+
func @placeholder(%arg0: f32) {
633+
sair.program {
634+
%0 = sair.static_range 2 : !sair.range
635+
%1 = sair.placeholder[d0:%0] : !sair.range<d0:range>
636+
%2 = sair.from_scalar %arg0 : !sair.value<(), f32>
637+
%3 = sair.copy[d0:%0, d1:%1] %2 : !sair.value<d0:range x d1:range(d0), f32>
638+
sair.exit
639+
}
640+
return
641+
}
642+
643+
// CHECK-LABEL: @placeholder_with_loop_nest
644+
func @placeholder_with_loop_nest(%arg0: f32) {
645+
sair.program {
646+
%0 = sair.static_range 2 : !sair.range
647+
%1 = sair.placeholder : !sair.range
648+
%2 = sair.from_scalar %arg0 : !sair.value<(), f32>
649+
%3 = sair.copy[d0:%0] %2 {
650+
loop_nest = [{name = "loopA", iter = #sair.mapping_expr<d0>}]
651+
} : !sair.value<d0:range, f32>
652+
%4 = sair.copy[d0:%1] %2 {
653+
loop_nest = [{name = "loopA", iter = #sair.mapping_expr<d0>}]
654+
} : !sair.value<d0:range, f32>
655+
sair.exit
656+
}
657+
return
658+
}

transforms/introduce_loops.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,14 @@ mlir::LogicalResult IntroduceLoop(SairMapOp op,
496496
LoopAttr loop = loop_nest.back().cast<LoopAttr>();
497497

498498
int dimension = loop.iter().cast<MappingDimExpr>().dimension();
499-
RangeOp range = cast<RangeOp>(op.domain()[dimension].getDefiningOp());
499+
mlir::Operation *dimension_op = op.domain()[dimension].getDefiningOp();
500+
if (isa<SairPlaceholderOp>(dimension_op)) {
501+
return dimension_op->emitError()
502+
<< "placeholders must be replaced by actual dimensions before "
503+
"introducing loops";
504+
}
505+
506+
RangeOp range = cast<RangeOp>(dimension_op);
500507
MappingAttr range_mapping =
501508
op.shape().Dimension(dimension).dependency_mapping().ResizeUseDomain(
502509
op.domain().size() - 1);

transforms/normalize_loops.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ class NormalizeLoopsPass : public NormalizeLoopsPassBase<NormalizeLoopsPass> {
272272
llvm::DenseMap<mlir::Attribute, std::pair<mlir::Value, DomainShapeDim>>
273273
loop_range_cache;
274274
program.walk([&](SairOp op) {
275-
// Do not normalize range operations.
276-
if (isa<RangeOp>(op.getOperation())) return;
275+
// Do not normalize range and placeholder operations.
276+
if (isa<RangeOp, SairPlaceholderOp>(op.getOperation())) return;
277277
if (mlir::failed(NormalizeLoops(op, iteration_spaces, builder,
278278
loop_range_cache))) {
279279
signalPassFailure();

util.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "llvm/ADT/STLExtras.h"
1818
#include "sair_op_interfaces.h"
19+
#include "sair_ops.h"
1920

2021
namespace sair {
2122

@@ -79,6 +80,10 @@ mlir::LogicalResult ResolveUnificationConstraint(
7980
llvm::SmallVectorImpl<ValueAccess> &target_domain) {
8081
mlir::MLIRContext *context = constraint.getContext();
8182

83+
// Ignore placeholders.
84+
mlir::Operation *defining_op = dimension.value.getDefiningOp();
85+
if (isa<SairPlaceholderOp>(defining_op)) return mlir::success();
86+
8287
if (constraint.isa<MappingNoneExpr>()) {
8388
constraint = MappingDimExpr::get(target_domain.size(), context);
8489
target_domain.push_back(dimension);

0 commit comments

Comments
 (0)