-
Notifications
You must be signed in to change notification settings - Fork 186
Expand file tree
/
Copy pathStablehloAggressiveSimplification.cpp
More file actions
1720 lines (1452 loc) · 65.8 KB
/
StablehloAggressiveSimplification.cpp
File metadata and controls
1720 lines (1452 loc) · 65.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License, Version 2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements optional canonicalization patterns for StableHLO ops.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/Base.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/PassUtils.h"
#include "stablehlo/transforms/optimization/Passes.h"
using llvm::SmallBitVector;
namespace mlir::stablehlo {
#define GEN_PASS_DEF_STABLEHLOAGGRESSIVESIMPLIFICATIONPASS
#include "stablehlo/transforms/optimization/Passes.h.inc"
namespace {
// Default-constructed pass options with static storage duration.
// SimplifyOpRewritePattern keeps only a reference to its options (its rvalue
// constructor is deleted), so any options it is given must outlive it; this
// object is presumably what callers pass when no explicit options are
// supplied — TODO confirm at the use site (not visible in this chunk).
static constexpr StablehloAggressiveSimplificationPassOptions kDefaultOptions;
/// Returns true iff `dims` is exactly [0, 1, ..., dims.size() - 1], i.e. the
/// identity sequence. An empty array trivially satisfies this.
static bool isIotaRange(ArrayRef<int64_t> dims) {
  int64_t expected = 0;
  for (int64_t dim : dims) {
    if (dim != expected) return false;
    ++expected;
  }
  return true;
}
/// Merges the discardable attributes of `sourceValue`'s defining op into the
/// discardable attributes of `destValue`'s defining op.
///
/// On a name collision the source attribute wins, with one exception: when
/// both ops carry `mhlo.frontend_attributes` as a DictionaryAttr, the two
/// dictionaries are merged entry by entry (again preferring source entries).
/// If either side's `mhlo.frontend_attributes` is not a dictionary, the
/// source value simply overwrites the destination value instead of tripping
/// an unchecked `cast`.
///
/// Returns false (and changes nothing) when either value has no defining op,
/// e.g. a block argument; returns true otherwise.
bool mergeDiscardableAttributes(Value sourceValue, Value destValue) {
  Operation* sourceOp = sourceValue.getDefiningOp();
  Operation* destOp = destValue.getDefiningOp();
  if (!sourceOp || !destOp) return false;

  auto sourceAttrs = sourceOp->getDiscardableAttrDictionary();
  if (!sourceAttrs) return true;
  auto destAttrs = destOp->getDiscardableAttrDictionary();
  if (!destAttrs) {
    destOp->setDiscardableAttrs(sourceAttrs);
    return true;
  }

  constexpr StringRef kFrontendAttrsName = "mhlo.frontend_attributes";
  NamedAttrList mergedAttrs(destAttrs);
  for (NamedAttribute attr : sourceAttrs.getValue()) {
    if (attr.getName().getValue() == kFrontendAttrsName) {
      // Only merge entry-wise when both sides really are dictionaries; the
      // attribute kind is not verified anywhere, so don't assume it.
      auto sourceFrontendAttrs = dyn_cast<DictionaryAttr>(attr.getValue());
      auto destFrontendAttrs = dyn_cast_or_null<DictionaryAttr>(
          mergedAttrs.get(kFrontendAttrsName));
      if (sourceFrontendAttrs && destFrontendAttrs) {
        // Merge frontend attributes, prioritizing source attributes.
        NamedAttrList frontendAttrs(destFrontendAttrs);
        for (NamedAttribute sourceAttr : sourceFrontendAttrs)
          frontendAttrs.set(sourceAttr.getName(), sourceAttr.getValue());
        mergedAttrs.set(kFrontendAttrsName,
                        frontendAttrs.getDictionary(destOp->getContext()));
        continue;
      }
    }
    // Otherwise prioritize source attributes.
    mergedAttrs.set(attr.getName(), attr.getValue());
  }
  destOp->setDiscardableAttrs(mergedAttrs.getDictionary(destOp->getContext()));
  return true;
}
/// Common base for the rewrite patterns in this file.
///
/// Behaves like `OpRewritePattern<OpType>` and additionally exposes the pass
/// options to the derived pattern. Only a reference to the options is kept,
/// so the options object must outlive the pattern.
template <typename OpType>
struct SimplifyOpRewritePattern : public OpRewritePattern<OpType> {
  SimplifyOpRewritePattern(
      MLIRContext* context,
      const StablehloAggressiveSimplificationPassOptions& passOptions,
      PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {})
      : OpRewritePattern<OpType>(context, benefit, generatedNames),
        options(passOptions) {}

  // Deleted on purpose: binding a temporary to the `options` reference
  // below would leave it dangling once the full-expression ends.
  SimplifyOpRewritePattern(
      MLIRContext* context,
      StablehloAggressiveSimplificationPassOptions&& passOptions,
      PatternBenefit benefit = 1,
      ArrayRef<StringRef> generatedNames = {}) = delete;

  /// Options of the owning pass (referenced, not owned).
  const StablehloAggressiveSimplificationPassOptions& options;
};
/// Operation matcher combinator: succeeds when `matcherA` matches the
/// operation, or else when `matcherB` does. `matcherB` is only tried if
/// `matcherA` fails.
template <typename MatcherA, typename MatcherB>
struct m_AnyOf {
  m_AnyOf(MatcherA first, MatcherB second) : matcherA(first), matcherB(second) {}

  bool match(Operation* op) {
    if (matcherA.match(op)) return true;
    return matcherB.match(op);
  }

  MatcherA matcherA;
  MatcherB matcherB;
};
// Deduction guide so `m_AnyOf(a, b)` needs no explicit template arguments.
template <typename MatcherA, typename MatcherB>
m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
/// Attribute matcher combinator: succeeds when `matcherA` matches the
/// attribute, or else when `matcherB` does. `matcherB` is only tried if
/// `matcherA` fails.
template <typename MatcherA, typename MatcherB>
struct m_AnyAttrOf {
  m_AnyAttrOf(MatcherA first, MatcherB second)
      : matcherA(first), matcherB(second) {}

  bool match(Attribute attr) {
    if (matcherA.match(attr)) return true;
    return matcherB.match(attr);
  }

  MatcherA matcherA;
  MatcherB matcherB;
};
// Deduction guide so `m_AnyAttrOf(a, b)` needs no explicit template arguments.
template <typename MatcherA, typename MatcherB>
m_AnyAttrOf(MatcherA, MatcherB) -> m_AnyAttrOf<MatcherA, MatcherB>;
//////////////////////////////////
// CompareOp
/////////////////////////////////
/// Returns the direction to use after swapping the two operands of a compare,
/// so that compare(a, b, d) == compare(b, a, invertDirection(d)).
/// EQ and NE are symmetric and come back unchanged; an out-of-range enum
/// value aborts via `report_fatal_error`.
static ComparisonDirection invertDirection(ComparisonDirection direction) {
  switch (direction) {
    case ComparisonDirection::LT:
      return ComparisonDirection::GT;
    case ComparisonDirection::LE:
      return ComparisonDirection::GE;
    case ComparisonDirection::GT:
      return ComparisonDirection::LT;
    case ComparisonDirection::GE:
      return ComparisonDirection::LE;
    case ComparisonDirection::EQ:
    case ComparisonDirection::NE:
      return direction;
  }
  using RawDirection = std::underlying_type_t<ComparisonDirection>;
  llvm::report_fatal_error(llvm::formatv(
      "Undefined enum value for `ComparisonDirection`: {0}",
      static_cast<RawDirection>(direction)));
}
/// Canonicalizes integer (SIGNED / UNSIGNED) `stablehlo.compare` ops:
///   compare(X, X, [EQ,GE,LE]) -> splat true
///   compare(X, X, [NE,GT,LT]) -> splat false
///   compare(cst, X, dir)      -> compare(X, cst, invertDirection(dir))
/// The last rewrite only fires when the LHS alone is constant. A missing or
/// non-integer `compare_type` makes the pattern bail out.
struct CompareOpCanon final : SimplifyOpRewritePattern<CompareOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(CompareOp op,
                                PatternRewriter& rewriter) const override {
    // Bail out on non-integer comparison.
    // TODO: Support more comparison types.
    std::optional<ComparisonType> compType = op.getCompareType();
    if (!compType) return failure();
    if (*compType != ComparisonType::SIGNED &&
        *compType != ComparisonType::UNSIGNED)
      return failure();

    RankedTensorType type = op.getType();
    ComparisonDirection direction = op.getComparisonDirection();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    // Comparing a value against itself has a statically known answer.
    if (lhs == rhs) {
      switch (direction) {
        case ComparisonDirection::EQ:
        case ComparisonDirection::GE:
        case ComparisonDirection::LE:
          rewriter.replaceOpWithNewOp<ConstantOp>(
              op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
          return success();
        case ComparisonDirection::GT:
        case ComparisonDirection::LT:
        case ComparisonDirection::NE:
          rewriter.replaceOpWithNewOp<ConstantOp>(op,
                                                  rewriter.getZeroAttr(type));
          return success();
      }
      llvm_unreachable("Unhandled case");
    }

    // The canonical form keeps a lone constant operand on the RHS.
    TypedAttr lhsCst;
    TypedAttr rhsCst;
    matchPattern(lhs, m_Constant(&lhsCst));
    matchPattern(rhs, m_Constant(&rhsCst));
    if (!lhsCst || rhsCst) return failure();

    ComparisonDirection swappedDirection = invertDirection(direction);
    rewriter.modifyOpInPlace(op, [&] {
      op.setComparisonDirection(swappedDirection);
      op->setOperands(ValueRange{rhs, lhs});
    });
    return success();
  }
};
//////////////////////////////////
// ConcatenateOp
/////////////////////////////////
// Pattern: concatenate(X) -> X
// A concatenate with a single input whose type already equals the result
// type is the identity; forward the input.
class ConcatenateOpNoop : public SimplifyOpRewritePattern<ConcatenateOp> {
 public:
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    auto inputs = op.getInputs();
    bool isNoop =
        inputs.size() == 1 && inputs.front().getType() == op.getType();
    if (!isNoop)
      return rewriter.notifyMatchFailure(op, "not single operand noop-concat");
    rewriter.replaceOp(op, inputs.front());
    return success();
  }
};
// Pattern: concatenate(X, Y, []) -> concatenate(X, Y)
// Drops operands whose extent along the concatenation axis is zero. If every
// operand would be dropped the op is left alone; that case is handled by the
// ZeroExtentToEmptyConstant pattern.
class ConcatenateOpRemoveEmpty
    : public SimplifyOpRewritePattern<ConcatenateOp> {
 public:
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    const auto axis = op.getDimension();
    llvm::SmallVector<Value> keptOperands;
    for (Value operand : op.getOperands()) {
      if (cast<ShapedType>(operand.getType()).getDimSize(axis) != 0)
        keptOperands.push_back(operand);
    }

    bool allEmpty = keptOperands.empty();
    bool nothingDropped = keptOperands.size() == op.getNumOperands();
    if (allEmpty || nothingDropped) return failure();

    rewriter.modifyOpInPlace(op, [&] { op->setOperands(keptOperands); });
    return success();
  }
};
// Pattern: concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
//
// Splices the inputs of a nested concatenate along the same dimension into
// this op. To avoid inflating the memory footprint, a nested concatenate is
// only flattened when it has exactly one use. Fails (no change) when no input
// qualifies.
class ConcatenateOpFlatten : public SimplifyOpRewritePattern<ConcatenateOp> {
 public:
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    // Build the flattened operand list in a single pass, remembering whether
    // any input was actually expanded.
    llvm::SmallVector<Value, 6> newOperands;
    bool needToFlatten = false;
    for (Value operand : op.getInputs()) {
      auto producer = operand.getDefiningOp<ConcatenateOp>();
      if (producer && producer->hasOneUse() &&
          producer.getDimension() == op.getDimension()) {
        ValueRange nestedInputs = producer.getInputs();
        newOperands.append(nestedInputs.begin(), nestedInputs.end());
        needToFlatten = true;
      } else {
        newOperands.push_back(operand);
      }
    }
    if (!needToFlatten)
      return rewriter.notifyMatchFailure(op, "no need to flatten");

    rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
    return success();
  }
};
/////////////////////////////////
// CustomCallOp
/////////////////////////////////
/// Moves an unregistered `mhlo.backend_config` attribute into the op's
/// `backend_config` and switches the custom call to API_VERSION_TYPED_FFI.
///
/// Only fires for API_VERSION_ORIGINAL custom calls whose existing
/// `backend_config` is an empty string, so a non-empty config is never
/// overwritten. All mutations go through `modifyOpInPlace` so the pattern
/// driver is notified of the in-place update.
struct CustomCallUnregisteredBackendConfigToFfi final
    : SimplifyOpRewritePattern<CustomCallOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(CustomCallOp op,
                                PatternRewriter& rewriter) const override {
    constexpr StringRef kMhloBackendConfigAttrName = "mhlo.backend_config";
    if (op.getApiVersion() != CustomCallApiVersion::API_VERSION_ORIGINAL)
      return rewriter.notifyMatchFailure(
          op, "Only match `custom_call` ops with `API_VERSION_ORIGINAL`.");

    auto mhloBackendConfigAttr = op->getAttr(kMhloBackendConfigAttrName);
    if (!mhloBackendConfigAttr)
      return rewriter.notifyMatchFailure(
          op, "No `mhlo.backend_config` attribute to fix.");

    auto oldBackendConfig =
        dyn_cast<StringAttr>(op.getBackendConfigOrDefault());
    if (!oldBackendConfig)
      return rewriter.notifyMatchFailure(
          op, "`op.getBackendConfigOrDefault()` didn't return a `StringAttr`.");
    if (!oldBackendConfig.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-empty `backend_config` attribute shouldn't be overwritten.");

    // In-place edits must be reported to the rewriter; mutating the op
    // directly bypasses the driver's change tracking.
    rewriter.modifyOpInPlace(op, [&] {
      op.setBackendConfigAttr(mhloBackendConfigAttr);
      op.setApiVersion(CustomCallApiVersion::API_VERSION_TYPED_FFI);
      op->removeAttr(kMhloBackendConfigAttrName);
    });
    return success();
  }
};
//////////////////////////////////
// BroadcastInDimOp
/////////////////////////////////
// Used in DRR file.
// Inverts `dims` viewed as a permutation: the result holds i at index
// dims[i]. This converts broadcast dimensions into the permutation of an
// equivalent transpose.
// NOTE(review): every entry of `dims` must lie in [0, dims.size()) (e.g. a
// permutation), otherwise the store below is out of bounds; presumably the
// DRR constraint guarantees this — confirm there.
DenseI64ArrayAttr getInvertedBroadcastDimensions(OpBuilder& b,
                                                 ArrayRef<int64_t> dims) {
  SmallVector<int64_t> inverse(dims.size());
  for (const auto& entry : llvm::enumerate(dims))
    inverse[entry.value()] = static_cast<int64_t>(entry.index());
  return b.getDenseI64ArrayAttr(inverse);
}
// Composes two dimension maps: result[i] = dims[dimsParent[i]]. Presumably
// `dims` / `dimsParent` are the broadcast dimensions of the inner / outer
// broadcast being merged — confirm against the DRR file.
DenseI64ArrayAttr getMergedBroadcastDimensions(OpBuilder& b,
                                               ArrayRef<int64_t> dims,
                                               ArrayRef<int64_t> dimsParent) {
  SmallVector<int64_t> composed;
  composed.reserve(dimsParent.size());
  for (int64_t parentDim : dimsParent) composed.push_back(dims[parentDim]);
  return b.getDenseI64ArrayAttr(composed);
}
//////////////////////////////////
// DynamicBroadcastInDimOp
/////////////////////////////////
/// Like `PatternRewriter::replaceOpWithNewOp`, but tolerant of result-type
/// refinement.
///
/// Creates an `OpTy` at `op`'s location from `args` and replaces `op` with
/// it. The new op may have more refined result types than `op` (e.g. static
/// instead of dynamic shapes). HLO users accept such refinements, but users
/// outside the HLO dialect (e.g. `func.return`) may not. So for each result
/// that has at least one user from a different dialect than `op`, a
/// `stablehlo.convert` back to the original result type is inserted and used
/// as the replacement value.
///
/// Returns the newly created op.
template <typename OpTy, typename... Args>
static OpTy refineOpWithNewOp(PatternRewriter& rewriter, Operation* op,
                              Args&&... args) {
  auto newOp =
      OpTy::create(rewriter, op->getLoc(), std::forward<Args>(args)...);
  assert(op->getNumResults() == newOp->getNumResults() &&
         "replacement op doesn't match results of original op");

  llvm::SmallVector<Value> replacements;
  replacements.reserve(op->getNumResults());
  for (auto [oldResult, newResult] :
       llvm::zip(op->getResults(), newOp->getResults())) {
    bool hasForeignUser =
        llvm::any_of(oldResult.getUsers(), [&](Operation* user) {
          return user->getDialect() != op->getDialect();
        });
    Value replacement = newResult;
    if (hasForeignUser) {
      // Restore the original type for users that can't take the refinement.
      replacement = ConvertOp::create(rewriter, op->getLoc(),
                                      oldResult.getType(), newResult);
    }
    replacements.push_back(replacement);
  }
  rewriter.replaceOp(op, replacements);
  return newOp;
}
/// If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
/// BroadcastInDimOp.
///
/// Requires a statically shaped operand. Then:
///   * if the result type is static, replace directly with broadcast_in_dim;
///   * otherwise, if `output_dimensions` is a constant, build a static result
///     type from it and replace via `refineOpWithNewOp`, which inserts a
///     convert back to the original type for non-HLO users.
struct DynamicBroadcastInDimOpNotActuallyDynamic final
    : SimplifyOpRewritePattern<DynamicBroadcastInDimOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    RankedTensorType operandType = op.getOperand().getType();
    if (!operandType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "requires operand static shape");

    RankedTensorType type = op.getType();
    // Output has static shape: replace with broadcast_in_dim.
    if (type.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
          op, type, op.getOperand(), op.getBroadcastDimensionsAttr());
      return success();
    }

    // output_dimensions are constant: use them as the static output shape,
    // then replace with broadcast_in_dim.
    if (llvm::SmallVector<int64_t> shape;
        succeeded(hlo::matchInts(op.getOutputDimensions(), shape))) {
      refineOpWithNewOp<BroadcastInDimOp>(
          rewriter, op, RankedTensorType::get(shape, type.getElementType()),
          op.getOperand(), op.getBroadcastDimensionsAttr());
      return success();
    }

    // The fallback condition is about output_dimensions, not
    // broadcast_dimensions; say so in the diagnostic.
    return rewriter.notifyMatchFailure(
        op, "requires output static shape or constant output dimensions");
  }
};
//////////////////////////////////
// DynamicGatherOp
/////////////////////////////////
/// Converts `attr`, which must be an ElementsAttr (enforced by `cast`) whose
/// elements can be read as APInt, into a DenseI64ArrayAttr. Each element is
/// sign-extended to 64 bits.
DenseI64ArrayAttr convertToI64Array(OpBuilder& b, Attribute attr) {
  auto elements = cast<ElementsAttr>(attr);
  SmallVector<int64_t> values;
  values.reserve(elements.getNumElements());
  for (const APInt& element : elements.getValues<APInt>())
    values.push_back(element.getSExtValue());
  return b.getDenseI64ArrayAttr(values);
}
//////////////////////////////////
// DynamicIotaOp
/////////////////////////////////
/// A dynamic_iota whose result type is fully static is equivalent to a plain
/// iota of that type along the same dimension; the shape operand is dropped.
struct DynamicIotaIsStatic : public SimplifyOpRewritePattern<DynamicIotaOp> {
  using SimplifyOpRewritePattern<DynamicIotaOp>::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto resultTy = cast<ShapedType>(iota.getType());
    if (resultTy.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<IotaOp>(iota, resultTy,
                                          iota.getIotaDimension());
      return success();
    }
    return rewriter.notifyMatchFailure(iota, "requires output static shape");
  }
};
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
// Pattern: dynamic_iota(shape, dim) ->
//   dynamic_broadcast_in_dim(dynamic_iota(slice(shape), 0), shape, [dim])
// The shape operand is converted to i64 (unless it already is) before the
// extent of `dim` is sliced out of it. Index-typed shapes are rejected.
struct DynamicIotaOpToBroadcast
    : public SimplifyOpRewritePattern<DynamicIotaOp> {
  using SimplifyOpRewritePattern<DynamicIotaOp>::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto resultType = cast<ShapedType>(iota.getType());
    if (resultType.getRank() < 2)
      return rewriter.notifyMatchFailure(iota, "requires rank >= 2");

    Location loc = iota.getLoc();
    auto dim = static_cast<int64_t>(iota.getIotaDimension());
    Value outputShape = iota.getOutputShape();
    auto outputShapeType = cast<ShapedType>(outputShape.getType());
    Type shapeElementType = outputShapeType.getElementType();
    if (shapeElementType.isIndex())
      return rewriter.notifyMatchFailure(
          iota, "index-typed shapes not supported; run "
                "shape-legalize-to-stablehlo first");

    // Normalize the shape to i64 so it can be sliced uniformly.
    Value shapeI64 = outputShape;
    if (!shapeElementType.isInteger(64)) {
      auto shapeI64Type = RankedTensorType::get(outputShapeType.getShape(),
                                                rewriter.getI64Type());
      shapeI64 = stablehlo::ConvertOp::create(rewriter, loc, shapeI64Type,
                                              outputShape);
    }

    // Extent of the iota dimension, as a one-element tensor.
    auto dimExtent = SliceOp::create(rewriter, loc, shapeI64,
                                     rewriter.getDenseI64ArrayAttr(dim),
                                     rewriter.getDenseI64ArrayAttr(dim + 1),
                                     rewriter.getDenseI64ArrayAttr(1));

    // Rank-1 iota along that extent, then broadcast it into place.
    auto rank1Type = RankedTensorType::get({resultType.getDimSize(dim)},
                                           resultType.getElementType());
    auto rank1Iota =
        DynamicIotaOp::create(rewriter, loc, rank1Type, dimExtent,
                              rewriter.getI64IntegerAttr(0));
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, resultType, rank1Iota, outputShape,
        rewriter.getDenseI64ArrayAttr(dim));
    return success();
  }
};
//////////////////////////////////
// DynamicReshapeOp
/////////////////////////////////
/// A dynamic_reshape whose result type is already fully static carries no
/// dynamic information; replace it with a plain reshape to that type.
struct DynamicReshapeOpIsStatic final
    : SimplifyOpRewritePattern<DynamicReshapeOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    RankedTensorType resultType = op.getType();
    if (resultType.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<ReshapeOp>(op, resultType, op.getOperand());
      return success();
    }
    return rewriter.notifyMatchFailure(op, "dynamic reshape not static");
  }
};
// Pattern: dynamic_reshape(op(dynamic_reshape(X, shape)), shape)
//            -> op(dynamic_reshape(X, shape))
//          [if op has same operand and result shape]
// The outer reshape is redundant when the shape-preserving producer's first
// operand was already reshaped to the identical output-shape value.
class DynamicReshapeOpSameOperandAndResultShape
    : public SimplifyOpRewritePattern<DynamicReshapeOp> {
 public:
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* producer = op.getOperand().getDefiningOp();
    bool producerPreservesShape =
        producer &&
        producer->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>();
    if (!producerPreservesShape)
      return rewriter.notifyMatchFailure(
          op, "dynamic reshape parent not same operand and result shape");

    auto innerReshape =
        producer->getOperand(0).getDefiningOp<DynamicReshapeOp>();
    if (!innerReshape)
      return rewriter.notifyMatchFailure(
          op, "dynamic reshape not wrapping same operand and result shape");

    if (innerReshape.getOutputShape() != op.getOutputShape()) return failure();

    rewriter.replaceOp(op, {producer->getResult(0)});
    return success();
  }
};
//////////////////////////////////
// DynamicSliceOp
/////////////////////////////////
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied the case when the `begin` input values are
// compile time constants and thus can be made into a tensor.
//
// Pattern: dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
//
// Each constant start index is clamped to [0, dim_size - slice_size] to mirror
// dynamic_slice semantics. Unsigned start indices are clamped as unsigned
// values; sign-extending them would turn large indices into negative ones.
struct DynamicSliceOpToSlice : public SimplifyOpRewritePattern<DynamicSliceOp> {
  using SimplifyOpRewritePattern<DynamicSliceOp>::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice,
                                PatternRewriter& rewriter) const override {
    Value input = dynamicSlice.getOperand();
    auto inputType = cast<ShapedType>(input.getType());
    if (!inputType.hasStaticShape())
      return rewriter.notifyMatchFailure(dynamicSlice,
                                         "dynamic slice input not static");

    auto sliceSizes = dynamicSlice.getSliceSizes();
    SmallVector<int64_t, 4> tempStartIndices;
    for (const auto& indexAndSliceStart :
         llvm::enumerate(dynamicSlice.getStartIndices())) {
      APInt val;
      Value start = indexAndSliceStart.value();
      int64_t index = indexAndSliceStart.index();
      if (!matchPattern(start, m_ConstantInt(&val)))
        return rewriter.notifyMatchFailure(dynamicSlice,
                                           "dynamic slice input not constant");

      // Clamp the indices within bounds to faithfully mirror dynamic slice
      // semantics. Assumes slice_sizes[i] <= dim_size(i) (presumably enforced
      // by the verifier), so the upper bound is non-negative.
      int64_t maxStart = inputType.getDimSize(index) - sliceSizes[index];
      int64_t clampedStart;
      if (getElementTypeOrSelf(start.getType()).isUnsignedInteger()) {
        // e.g. ui32 0xFFFFFFFF is a huge index (clamps to maxStart), not -1.
        clampedStart = static_cast<int64_t>(
            val.getLimitedValue(static_cast<uint64_t>(maxStart)));
      } else {
        clampedStart = std::clamp(val.getSExtValue(),
                                  static_cast<int64_t>(0), maxStart);
      }
      tempStartIndices.push_back(clampedStart);
    }

    // At this point we've determined that the start indices are all
    // constants; pack them into a single tensor.
    auto sliceStartIndices = rewriter.getDenseI64ArrayAttr(tempStartIndices);
    SmallVector<int64_t, 4> tempSliceLimits;
    for (const auto& [start, size] : llvm::zip(tempStartIndices, sliceSizes)) {
      tempSliceLimits.push_back(start + size);
    }
    auto sliceLimits = rewriter.getDenseI64ArrayAttr(tempSliceLimits);
    auto sliceStrides = rewriter.getDenseI64ArrayAttr(
        SmallVector<int64_t, 4>(inputType.getRank(), 1));
    rewriter.replaceOpWithNewOp<SliceOp>(dynamicSlice, input, sliceStartIndices,
                                         sliceLimits, sliceStrides);
    return success();
  }
};
//////////////////////////////////
// RealDynamicSliceOp
/////////////////////////////////
// Pattern: real_dynamic_slice(X, start, limit, strides)
//            -> dynamic_slice(X, start, limit, strides)
//          [if strides, start are constants, limit = start + constant]
// Only unit strides qualify (dynamic_slice has no strides operand), and the
// slice sizes must be static, i.e. `limit` is `start + cst` or `cst + start`.
struct RealDynamicSliceOpToDynamicSlice
    : public SimplifyOpRewritePattern<RealDynamicSliceOp> {
  using SimplifyOpRewritePattern<RealDynamicSliceOp>::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(RealDynamicSliceOp op,
                                PatternRewriter& rewriter) const override {
    // DynamicSliceOp implicitly uses unit strides.
    DenseIntElementsAttr stridesAttr;
    if (!matchPattern(op.getStrides(), m_Constant(&stridesAttr)))
      return rewriter.notifyMatchFailure(op, "requires constant strides");
    bool allUnitStrides =
        llvm::all_of(stridesAttr.getValues<APInt>(),
                     [](const APInt& stride) { return stride == 1; });
    if (!allUnitStrides)
      return rewriter.notifyMatchFailure(op, "requires unit strides");

    // Static slice sizes show up as `limit = start + constant` (either
    // operand order). The all-constant case is folded to SliceOp elsewhere.
    DenseIntElementsAttr sliceSizesAttr;
    auto startMatcher = matchers::m_Val(op.getStartIndices());
    bool limitIsStartPlusConstant =
        matchPattern(op.getLimitIndices(),
                     m_Op<AddOp>(startMatcher, m_Constant(&sliceSizesAttr))) ||
        matchPattern(op.getLimitIndices(),
                     m_Op<AddOp>(m_Constant(&sliceSizesAttr), startMatcher));
    if (!limitIsStartPlusConstant)
      return rewriter.notifyMatchFailure(
          op, "requires limit indices equal to start indices plus constant");

    // DynamicSliceOp::slice_sizes is i64-only, whereas the constant may have
    // any integer element type; sign-extend every entry.
    SmallVector<int64_t> sliceSizes;
    for (const APInt& size : sliceSizesAttr.getValues<APInt>())
      sliceSizes.push_back(size.getSExtValue());

    // RealDynamicSliceOp takes start indices as one 1-D tensor, while
    // DynamicSliceOp takes one 0-D tensor per dimension: slice out each
    // element and reshape it to a scalar tensor.
    Value startTensor = op.getStartIndices();
    auto scalarStartType = RankedTensorType::get(
        {}, cast<ShapedType>(startTensor.getType()).getElementType());
    SmallVector<Value> startIndices;
    startIndices.reserve(sliceSizes.size());
    for (int64_t i = 0, e = static_cast<int64_t>(sliceSizes.size()); i < e;
         ++i) {
      auto element = SliceOp::create(rewriter, op.getLoc(), startTensor,
                                     rewriter.getDenseI64ArrayAttr(i),
                                     rewriter.getDenseI64ArrayAttr(i + 1),
                                     rewriter.getDenseI64ArrayAttr(1));
      auto scalar = ReshapeOp::create(rewriter, op.getLoc(), scalarStartType,
                                      element);
      startIndices.push_back(scalar);
    }

    rewriter.replaceOpWithNewOp<DynamicSliceOp>(
        op, op.getOperand(), startIndices,
        rewriter.getDenseI64ArrayAttr(sliceSizes));
    return success();
  }
};
//////////////////////////////////
// ReduceOp
/////////////////////////////////
// Pattern: reduce[A](_, _, fn:return A) -> A...
//
// If every value returned by the reducer body is defined outside the body,
// each reduce result is that value, and the reduce can be replaced by it
// directly. This only holds when:
//   * each returned value's type equals the corresponding result type. The
//     reducer returns rank-0 tensors while the results keep the non-reduced
//     dimensions, so forwarding directly is only type-correct when they agree
//     (e.g. a full reduction).
//   * no reduced dimension is empty. Reducing over an empty dimension yields
//     the init value, not the returned value; statically empty inputs are
//     left to ReduceOpEmptyCanon.
// NOTE(review): a dynamic dimension that is zero at runtime is not detected
// here.
struct ReduceOpNoopVariableReturn final : SimplifyOpRewritePattern<ReduceOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    auto retOp = dyn_cast<ReturnOp>(op.getBody().front().getTerminator());
    if (!retOp) return failure();

    // All returned values must come from outside the reducer region.
    Region* retRegion = retOp->getParentRegion();
    if (llvm::any_of(retOp.getResults(), [retRegion](Value result) {
          return result.getParentRegion() == retRegion;
        }))
      return failure();

    // Forwarding a value of a different type would produce invalid IR.
    if (!llvm::equal(retOp.getResults().getTypes(), op.getResultTypes()))
      return rewriter.notifyMatchFailure(
          op, "returned value types differ from reduce result types");

    // A statically empty input makes the results equal to the init values.
    auto inputType =
        dyn_cast<RankedTensorType>(op.getInputs().getType().front());
    if (inputType && llvm::is_contained(inputType.getShape(), 0))
      return rewriter.notifyMatchFailure(op, "reduction over empty input");

    rewriter.replaceOp(op, retOp.getResults());
    return success();
  }
};
// Pattern: reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
// When the inputs contain a zero-sized dimension, each result is its init
// value broadcast to the result type. Static inputs use broadcast_in_dim.
// Dynamic inputs use dynamic_broadcast_in_dim with reified result shapes;
// the pattern fails if reification fails.
struct ReduceOpEmptyCanon final : SimplifyOpRewritePattern<ReduceOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    // All reduce inputs share one shape (element types may differ), so the
    // first input is representative.
    auto inputType = cast<RankedTensorType>(op.getInputs().getType().front());
    if (!llvm::is_contained(inputType.getShape(), 0)) return failure();

    Location loc = op.getLoc();
    DenseI64ArrayAttr noDims = rewriter.getDenseI64ArrayAttr({});

    if (!inputType.hasStaticShape()) {
      SmallVector<Value> resultShapes;
      if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(),
                                          resultShapes)))
        return failure();
      SmallVector<Value> replacements(op.getNumResults());
      for (auto [replacement, init, shape, resultType] :
           llvm::zip_equal(replacements, op.getInitValues(), resultShapes,
                           op.getResultTypes())) {
        replacement = DynamicBroadcastInDimOp::create(
            rewriter, loc, resultType, init, shape, noDims);
      }
      rewriter.replaceOp(op, replacements);
      return success();
    }

    SmallVector<Value> replacements(op.getNumResults());
    for (auto [replacement, init, resultType] : llvm::zip_equal(
             replacements, op.getInitValues(), op.getResultTypes())) {
      replacement =
          BroadcastInDimOp::create(rewriter, loc, resultType, init, noDims);
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
// Pattern: reduce(in_1, in_2, init_1, init_2) -> reduce(in_1, init_1)
//          [if result #2 is unused and result #1 does not read pair #2]
//
// Drops every (input, init_value) operand pair whose reduce result has no
// uses, together with the reducer-body operations that only feed the dropped
// results. The rewrite is only performed when the set of reducer argument
// pairs read by the kept results is exactly the set of kept results.
struct ReduceOpUnusedResultCanon final : SimplifyOpRewritePattern<ReduceOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
                  [](OpResult result) { return !result.use_empty(); });

    if (usedResults.size() == op.getNumResults())
      return rewriter.notifyMatchFailure(op, "all operation results have uses");

    // Operands are [inputs..., init_values...]: operand `i` and operand
    // `i + numOperandPairs` form pair `i`. Reducer block arguments follow the
    // same two-halves layout, which the `% numOperandPairs` below relies on.
    const auto pairSize = 2;
    const auto numOperands = op.getNumOperands();
    const auto numOperandPairs = numOperands / pairSize;

    Block& reducerBlock = op.getBody().front();
    auto retOp = cast<ReturnOp>(reducerBlock.getTerminator());

    assert(numOperandPairs == op.getNumResults() &&
           numOperandPairs == retOp.getNumOperands());

    SmallVector<Value> workList;
    // Only values defined directly in the reducer region are traced; values
    // from enclosing regions and from nested regions' own scopes are ignored.
    auto addToWorkList = [&workList,
                          reducerBody = retOp->getParentRegion()](Value v) {
      if (v.getParentRegion() == reducerBody) workList.push_back(v);
    };

    SmallPtrSet<Operation*, 16> usedOps;
    SmallBitVector usedArgs(numOperands);
    SmallBitVector usedReturnOperands(numOperandPairs);
    for (const auto& usedResult : usedResults) {
      auto resultNo = usedResult.getResultNumber();
      usedReturnOperands.set(resultNo);

      // Follow the def-use chain starting from the return operand to identify
      // which argument pairs are used to compute it.
      addToWorkList(retOp.getOperand(resultNo));
      while (!workList.empty()) {
        auto definition = workList.pop_back_val();
        if (auto blockArg = dyn_cast<BlockArgument>(definition)) {
          // Using one argument implies using the whole argument pair.
          const auto pairNo = blockArg.getArgNumber() % numOperandPairs;
          usedArgs.set(pairNo);
          usedArgs.set(pairNo + numOperandPairs);
        } else if (auto* defOp = definition.getDefiningOp()) {
          // Expand each op once; without this, shared subexpressions are
          // re-traversed once per path reaching them.
          if (!usedOps.insert(defOp).second) continue;
          // Walk nested ops too, so reducer values captured inside a nested
          // region are also treated as used (walk visits defOp itself).
          defOp->walk([&](Operation* nested) {
            for (Value operand : nested->getOperands()) addToWorkList(operand);
          });
        }
      }
    }

    const auto newNumOperandPairs = usedResults.size();
    const auto newNumOperands = newNumOperandPairs * pairSize;
    if (newNumOperands != usedArgs.count()) {
      return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) {
        diag << "non-conservative case: " << newNumOperandPairs
             << " return results should be matched with " << newNumOperands
             << " operands, but got " << usedArgs.count();
      });
    }

    // Matching counts are not sufficient: the kept argument pairs must be
    // exactly the kept results' pairs. Otherwise a kept result reads the
    // accumulator/input of a dropped pair (e.g. crosswise-feeding results),
    // and the new reduce would pair inputs with the wrong result types.
    for (auto i : llvm::seq(0u, numOperandPairs)) {
      if (usedArgs[i] != usedReturnOperands[i])
        return rewriter.notifyMatchFailure(
            op, "reducer argument pairs do not match the used results");
    }

    // After the check above, "result i is kept" and "pair i is kept" coincide.
    SmallVector<Value> newInputs;
    SmallVector<Value> newInitVals;
    SmallVector<Type> newElementTypes;
    for (auto i : llvm::seq(0u, numOperandPairs)) {
      if (!usedReturnOperands[i]) continue;
      newElementTypes.push_back(
          getElementTypeOrSelf(retOp.getOperand(i).getType()));
      newInputs.push_back(op.getOperand(i));
      newInitVals.push_back(op.getOperand(i + numOperandPairs));
    }

    auto newOp = ReduceOp::create(rewriter, op.getLoc(), newInputs, newInitVals,
                                  op.getDimensionsAttr(), newElementTypes);
    Block* newReducerBlock = rewriter.createBlock(&newOp.getBody());

    IRMapping mapper;
    for (auto arg : reducerBlock.getArguments())
      if (usedArgs[arg.getArgNumber()])
        mapper.map(arg,
                   newReducerBlock->addArgument(arg.getType(), arg.getLoc()));

    rewriter.setInsertionPointToStart(newReducerBlock);
    for (Operation& bodyOp : reducerBlock.getOperations())
      if (usedOps.contains(&bodyOp)) rewriter.clone(bodyOp, mapper);

    // A return operand defined outside the reducer region is never mapped,
    // so fall back to the original value rather than asserting in lookup().
    SmallVector<Value> newReturnOperands;
    for (const auto& en : llvm::enumerate(retOp.getOperands()))
      if (usedReturnOperands[en.index()])
        newReturnOperands.push_back(mapper.lookupOrDefault(en.value()));
    ReturnOp::create(rewriter, retOp.getLoc(), newReturnOperands);

    // Build new results list (unused entries will be null).
    SmallVector<Value> newResults(op.getNumResults());
    for (const auto& [i, result] : llvm::enumerate(usedResults)) {
      newResults[result.getResultNumber()] = newOp.getResult(i);
    }
    rewriter.replaceOp(op, newResults);
    return success();
  }
};
/////////////////////////////////
// GetDimensionSizeOp
/////////////////////////////////
// TODO: This is duplicated with a pattern in shape refinement, consider
// consolidating.
// Pattern: get_dimension_size(X, i) -> X.shape[i]
//
// When the queried dimension of X has a static extent, the op is replaced by
// a splat constant of its own result type holding that extent.
struct GetDimensionSizeOpCanon final
    : SimplifyOpRewritePattern<GetDimensionSizeOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(GetDimensionSizeOp op,
                                PatternRewriter& rewriter) const override {
    RankedTensorType sourceType = op.getOperand().getType();
    const int64_t extent = sourceType.getDimSize(op.getDimension());
    // A negative extent means the dimension is not statically known.
    if (extent < 0) return failure();

    auto resultType = op.getType();
    auto intType = cast<IntegerType>(resultType.getElementType());
    IntegerAttr extentAttr = rewriter.getIntegerAttr(intType, extent);
    auto folded = DenseElementsAttr::get(resultType, extentAttr);
    rewriter.replaceOpWithNewOp<ConstantOp>(op, folded);
    return success();
  }
};
//////////////////////////////////
// GatherOp
/////////////////////////////////
/// Converts gather ops to slice ops in case we have a single set of constant
/// indices.
///
/// Each constant start index is clamped to [0, dim_size - slice_size] for the
/// operand dimension it maps to, mirroring gather's clamping semantics.
/// Indices with an unsigned element type (`uiN`) are read zero-extended;
/// signless/signed ones are read sign-extended.
// Pattern: gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
struct GatherOpCanon final : SimplifyOpRewritePattern<GatherOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.getStartIndices(), m_Constant(&index)))
      return failure();

    GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers();
    if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO: Remove when the verifier catches this case that is
    // invalid if all previous condition holds.
    if (index.getNumElements() !=
        static_cast<int64_t>(dnums.getStartIndexMap().size())) {
      return failure();
    }

    auto operandType = cast<RankedTensorType>(gather->getOperand(0).getType());
    if (!operandType.hasStaticShape()) return failure();

    // Reading e.g. a `ui8` index of 200 with getSExtValue() would yield -56
    // and clamp to 0, so unsigned element types must be zero-extended.
    const bool unsignedIndices = index.getElementType().isUnsignedInteger();

    auto sliceEnd = llvm::to_vector(gather.getSliceSizes());
    SmallVector<int64_t> sliceStart(sliceEnd.size(), 0);
    for (auto [mapIndex, value] :
         llvm::zip_equal(dnums.getStartIndexMap(), index.getValues<APInt>())) {
      const int64_t dimSize = operandType.getDimSize(mapIndex);
      // Unsigned values are capped at dimSize (an upper bound of the clamp
      // range below) so the conversion to int64_t cannot wrap negative.
      const int64_t requested =
          unsignedIndices ? static_cast<int64_t>(value.getLimitedValue(
                                static_cast<uint64_t>(dimSize)))
                          : value.getSExtValue();
      // Clamp the indices within bounds to faithfully mirror gather semantics.
      const int64_t offset =
          std::clamp(requested, int64_t{0}, dimSize - sliceEnd[mapIndex]);
      sliceStart[mapIndex] += offset;
      sliceEnd[mapIndex] += offset;
    }

    SmallVector<int64_t> sliceStride(sliceEnd.size(), 1);
    SmallVector<int64_t> sliceShape(sliceEnd.size());
    for (auto [shapeElem, startElem, endElem] :
         llvm::zip_equal(sliceShape, sliceStart, sliceEnd)) {
      shapeElem = endElem - startElem;
    }

    Type elementType = gather.getType().getElementType();
    auto sliceType = RankedTensorType::get(sliceShape, elementType);
    Value result = SliceOp::create(rewriter, gather.getLoc(), sliceType,
                                   gather.getOperand(),
                                   rewriter.getDenseI64ArrayAttr(sliceStart),
                                   rewriter.getDenseI64ArrayAttr(sliceEnd),
                                   rewriter.getDenseI64ArrayAttr(sliceStride));

    // Drop the collapsed (size-1) slice dimensions from the result shape.
    ArrayRef<int64_t> collapsedSliceDims = dnums.getCollapsedSliceDims();
    if (!collapsedSliceDims.empty()) {
      llvm::SmallVector<int64_t> reshapeShape;
      for (auto [idx, dim] : llvm::enumerate(sliceShape)) {
        if (!llvm::is_contained(collapsedSliceDims, idx))
          reshapeShape.push_back(dim);
      }
      auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
      result =
          ReshapeOp::create(rewriter, gather.getLoc(), reshapeType, result);
    }

    // Keep the exact original result type for the replacement value.
    result.setType(gather.getType());

    rewriter.replaceOp(gather, result);
    return success();
  }
};
//////////////////////////////////
// IotaOp
/////////////////////////////////