@@ -48,6 +48,7 @@ limitations under the License.
48
48
#include " mlir/include/mlir/IR/Attributes.h"
49
49
#include " mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
50
50
#include " mlir/include/mlir/IR/OpDefinition.h"
51
+ #include " mlir/include/mlir/IR/Visitors.h"
51
52
#include " jaxlib/mosaic/dialect/tpu/layout.h"
52
53
#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
53
54
#include " xla/layout.h"
@@ -122,10 +123,7 @@ class VectorLayoutInferer {
122
123
123
124
LogicalResult inferBlock (
124
125
Block &block,
125
- const std::function<LogicalResult(Operation *)> &match_terminator,
126
- // TODO(jevinjiang): Propagate this flag deeper because it won't work when
127
- // there is an op with blocks inside this block.
128
- bool override_layout = false) {
126
+ const std::function<LogicalResult(Operation *)> &match_terminator) {
129
127
for (Operation &any_op : block.without_terminator ()) {
130
128
VLOG (kLayoutLog ) << Print (&any_op);
131
129
if (any_op.hasAttr (" in_layout" ) || any_op.hasAttr (" out_layout" )) {
@@ -134,8 +132,6 @@ class VectorLayoutInferer {
134
132
any_op.hasAttr (" in_layout" ) && any_op.hasAttr (" out_layout" ),
135
133
" expect layout attributes in tpu::AssumeLayoutOp" );
136
134
continue ;
137
- } else if (override_layout) {
138
- // Intend to override the layouts attribute.
139
135
} else {
140
136
any_op.emitOpError (" layout attributes already attached" );
141
137
return failure ();
@@ -508,35 +504,15 @@ class VectorLayoutInferer {
508
504
op->getNumOperands () == 3 + op.getNumResults (),
509
505
" expected num_operands is equal to 3 + num_results in scf.for" );
510
506
511
- SmallVector<Layout, 4 > in_layouts = getLayoutFromOperands (op);
512
- // Drop the first 3 layouts for lower bound, upper bound and step.
513
- ArrayRef<Layout> arg_layouts = ArrayRef<Layout>(in_layouts).drop_front (3 );
514
- SmallVector<tpu::AssumeLayoutOp, 4 > assume_layout_ops;
515
- assume_layout_ops.reserve (arg_layouts.size ());
516
- // Use tpu.assume_layout to annotate every block argument with the layout of
517
- // the corresponding operand in forOp and replace all uses of the block
518
- // argument with the result of tpu.assume_layout.
519
- ImplicitLocOpBuilder builder =
520
- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getBody ());
521
-
522
- // Drop the induction_variable and layouts of bounds+step (respectively).
523
- for (auto [iter_arg, layout] : llvm::zip_equal (
524
- op.getBody ()->getArguments ().drop_front (1 ), arg_layouts)) {
525
- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
526
- assume_layout_ops.push_back (nullptr );
527
- continue ;
528
- }
529
- auto assume_layout_op =
530
- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
531
- setLayout (assume_layout_op, layout, layout);
532
- assume_layout_ops.push_back (assume_layout_op);
533
- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
534
- return operand.getOwner () != assume_layout_op;
535
- });
536
- }
537
-
538
- if (inferBlock (*op.getBody (), match_yield).failed ()) {
539
- return failure ();
507
+ auto in_layouts = getLayoutFromOperands (op);
508
+ // Drop the input layouts for lower bound, upper bound. But keep the layout
509
+ // for step because it matches with induction variable in arguments.
510
+ auto arg_layouts = ArrayRef<Layout>(in_layouts).drop_front (2 );
511
+ if (assumeLayoutsForBlockArgs (*op.getBody (), arg_layouts).failed () ||
512
+ inferBlock (*op.getBody (), match_yield).failed ()) {
513
+ return op.emitOpError (
514
+ " failed to infer layout with initial layouts for body in "
515
+ " scf.for op" );
540
516
}
541
517
auto yield_op = op.getBody ()->getTerminator ();
542
518
auto yield_in_layouts = getLayoutFromOperands (yield_op);
@@ -546,7 +522,8 @@ class VectorLayoutInferer {
546
522
int out_idx = 0 ;
547
523
bool require_reinfer = false ;
548
524
for (auto [in_layout, yield_layout, result] :
549
- llvm::zip_equal (ArrayRef<Layout>(in_layouts).drop_front (3 ),
525
+ llvm::zip_equal (arg_layouts.drop_front (
526
+ 1 ), // Drop the layout for induction variable.
550
527
yield_in_layouts, op.getResults ())) {
551
528
if (auto vty = dyn_cast<VectorType>(result.getType ())) {
552
529
if (!in_layout.has_value ()) {
@@ -586,24 +563,25 @@ class VectorLayoutInferer {
586
563
++out_idx;
587
564
}
588
565
if (require_reinfer) {
566
+ // Force same layouts in input layout but skip the first 3 layouts for
567
+ // lower bound, upper bound and step.
568
+ std::copy (out_layouts.begin (), out_layouts.end (), in_layouts.begin () + 3 );
569
+
589
570
// Terminator in the loop will carry layouts to the next loop but
590
571
// the loop's block args' layouts are determined by the initial inputs. We
591
572
// need to force the same layouts for all in order to make layouts be
592
573
// consistent across all branches. To ensure that, we need to reprocess
593
574
// layout inference for the entire body with the final consolidated
594
575
// layout.
595
- for ( int64_t i = 0 ; i < out_layouts. size (); ++i) {
596
- if (assume_layout_ops[i]) {
597
- setLayout (assume_layout_ops[i], out_layouts[i], out_layouts[i]);
598
- }
599
- }
600
- if ( inferBlock (* op.getBody (), match_yield, /* override_layout= */ true )
601
- . failed ()) {
602
- return op. emitOpError ( " failed to infer layout for scf.for op" );
576
+ clearBlockLayouts (*op. getBody ());
577
+ if (assumeLayoutsForBlockArgs (*op. getBody (),
578
+ ArrayRef<Layout>(in_layouts). drop_front ( 2 ))
579
+ . failed () ||
580
+ inferBlock (*op. getBody (), match_yield). failed ()) {
581
+ return op.emitOpError (
582
+ " failed to infer layout with compatible layouts for body in "
583
+ " scf.for op" );
603
584
}
604
- std::copy (out_layouts.begin (), out_layouts.end (),
605
- in_layouts.begin () + 3 ); // Skip first 3 layouts for lower
606
- // bound, upper bound and step.
607
585
}
608
586
setInLayout (yield_op, out_layouts);
609
587
setLayout (op, in_layouts, out_layouts);
@@ -622,53 +600,19 @@ class VectorLayoutInferer {
622
600
TPU_CHECK_OP (op.getNumRegions () == 2 , " expected two blocks for scf.while" );
623
601
624
602
SmallVector<Layout, 4 > in_layouts = getLayoutFromOperands (op);
625
- SmallVector<tpu::AssumeLayoutOp, 4 > before_assume_layout_ops;
626
- before_assume_layout_ops.reserve (in_layouts.size ());
627
- SmallVector<tpu::AssumeLayoutOp, 4 > after_assume_layout_ops;
628
- after_assume_layout_ops.reserve (in_layouts.size ());
629
603
630
- // Use tpu.assume_layout to annotate every block argument with the layout of
631
- // the corresponding operand in WhileOp and replace all uses of the block
632
- // argument with the result of tpu.assume_layout.
633
- ImplicitLocOpBuilder builder =
634
- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getBeforeBody ());
635
- for (auto [iter_arg, layout] :
636
- llvm::zip_equal (op.getBeforeBody ()->getArguments (), in_layouts)) {
637
- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
638
- before_assume_layout_ops.push_back (nullptr );
639
- continue ;
640
- }
641
- auto assume_layout_op =
642
- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
643
- setLayout (assume_layout_op, layout, layout);
644
- before_assume_layout_ops.push_back (assume_layout_op);
645
- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
646
- return operand.getOwner () != assume_layout_op;
647
- });
648
- }
649
- if (inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
650
- return failure ();
604
+ if (assumeLayoutsForBlockArgs (*op.getBeforeBody (), in_layouts).failed () ||
605
+ inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
606
+ return op.emitOpError (
607
+ " failed to infer layout with initial layouts for before body in "
608
+ " scf.while op" );
651
609
}
652
610
653
- builder =
654
- ImplicitLocOpBuilder::atBlockBegin (op.getLoc (), op.getAfterBody ());
655
- for (auto [iter_arg, layout] :
656
- llvm::zip_equal (op.getAfterBody ()->getArguments (), in_layouts)) {
657
- if (!dyn_cast<VectorType>(iter_arg.getType ())) {
658
- after_assume_layout_ops.push_back (nullptr );
659
- continue ;
660
- }
661
- auto assume_layout_op =
662
- builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
663
- setLayout (assume_layout_op, layout, layout);
664
- after_assume_layout_ops.push_back (assume_layout_op);
665
- iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
666
- return operand.getOwner () != assume_layout_op;
667
- });
668
- }
669
-
670
- if (inferBlock (*op.getAfterBody (), match_yield).failed ()) {
671
- return failure ();
611
+ if (assumeLayoutsForBlockArgs (*op.getAfterBody (), in_layouts).failed () ||
612
+ inferBlock (*op.getAfterBody (), match_yield).failed ()) {
613
+ return op.emitOpError (
614
+ " failed to infer layout with initial layouts for after body in "
615
+ " scf.while op" );
672
616
}
673
617
674
618
auto *cond_op = op.getBeforeBody ()->getTerminator ();
@@ -738,27 +682,26 @@ class VectorLayoutInferer {
738
682
++out_idx;
739
683
}
740
684
if (require_reinfer) {
685
+ clearBlockLayouts (*op.getBeforeBody ());
686
+ clearBlockLayouts (*op.getAfterBody ());
741
687
// Terminator in the loop will carry layouts to the next loop but
742
688
// the loop's block args' layouts are determined by the initial inputs. We
743
689
// need to force the same layouts for all in order to make layouts be
744
690
// consistent across all branches. To ensure that, we need to reprocess
745
691
// layout inference for the entire body with the final consolidated
746
692
// layout.
747
- for (int64_t i = 0 ; i < out_layouts.size (); ++i) {
748
- if (before_assume_layout_ops[i]) {
749
- setLayout (before_assume_layout_ops[i], out_layouts[i],
750
- out_layouts[i]);
751
- }
752
- if (after_assume_layout_ops[i]) {
753
- setLayout (after_assume_layout_ops[i], out_layouts[i], out_layouts[i]);
754
- }
755
- }
756
- if (inferBlock (*op.getBeforeBody (), match_condition,
757
- /* override_layout=*/ true )
693
+ if (assumeLayoutsForBlockArgs (*op.getBeforeBody (), out_layouts)
758
694
.failed () ||
759
- inferBlock (*op.getAfterBody (), match_yield, /* override_layout=*/ true )
760
- .failed ()) {
761
- return op.emitOpError (" failed to infer layout for scf.while op" );
695
+ inferBlock (*op.getBeforeBody (), match_condition).failed ()) {
696
+ return op.emitOpError (
697
+ " failed to infer layout with compatible layouts for before body in "
698
+ " scf.while op" );
699
+ }
700
+ if (assumeLayoutsForBlockArgs (*op.getAfterBody (), out_layouts).failed () ||
701
+ inferBlock (*op.getAfterBody (), match_yield).failed ()) {
702
+ return op.emitOpError (
703
+ " failed to infer layout with compatible layouts for after body in "
704
+ " scf.while op" );
762
705
}
763
706
}
764
707
std::copy (out_layouts.begin (), out_layouts.end (),
@@ -1854,6 +1797,53 @@ class VectorLayoutInferer {
1854
1797
return true ;
1855
1798
}
1856
1799
1800
+ LogicalResult assumeLayoutsForBlockArgs (Block &block,
1801
+ ArrayRef<Layout> layouts) {
1802
+ auto op = block.getParentOp ();
1803
+ if (layouts.size () != block.getNumArguments ()) {
1804
+ return op->emitOpError (
1805
+ " Block arguments must have the same number of layouts" );
1806
+ }
1807
+ // Use tpu.assume_layout to annotate every block argument with the layout of
1808
+ // the corresponding operand and replace all uses of the block argument with
1809
+ // the result of tpu.assume_layout.
1810
+ ImplicitLocOpBuilder builder =
1811
+ ImplicitLocOpBuilder::atBlockBegin (op->getLoc (), &block);
1812
+ for (auto [iter_arg, layout] :
1813
+ llvm::zip_equal (block.getArguments (), layouts)) {
1814
+ if (!dyn_cast<VectorType>(iter_arg.getType ())) {
1815
+ continue ;
1816
+ }
1817
+ if (llvm::any_of (iter_arg.getUsers (), [](Operation *user) {
1818
+ return isa<tpu::AssumeLayoutOp>(user);
1819
+ })) {
1820
+ return op->emitOpError (" Expected no assume layout for block arguments" );
1821
+ }
1822
+ auto assume_layout_op =
1823
+ builder.create <AssumeLayoutOp>(iter_arg.getType (), iter_arg);
1824
+ setLayout (assume_layout_op, layout, layout);
1825
+ iter_arg.replaceUsesWithIf (assume_layout_op, [&](OpOperand &operand) {
1826
+ return operand.getOwner () != assume_layout_op;
1827
+ });
1828
+ }
1829
+ return success ();
1830
+ }
1831
+
1832
+ void clearBlockLayouts (Block &block) {
1833
+ block.walk ([&](Operation *op) {
1834
+ // We need to remove assume_layout ops in each block. Otherwise, we will
1835
+ // create extra assume_layout ops for nested blocks.
1836
+ if (auto assume_op = dyn_cast<tpu::AssumeLayoutOp>(op)) {
1837
+ assume_op.getResult ().replaceAllUsesWith (assume_op.getInput ());
1838
+ assume_op->erase ();
1839
+ return WalkResult::advance ();
1840
+ }
1841
+ op->removeAttr (" in_layout" );
1842
+ op->removeAttr (" out_layout" );
1843
+ return WalkResult::advance ();
1844
+ });
1845
+ }
1846
+
1857
1847
void setInLayout (Operation *op, ArrayRef<Layout> in) {
1858
1848
CHECK_EQ (in.size (), op->getNumOperands ()) << Print (op);
1859
1849
SmallVector<Attribute, 4 > in_attrs;
0 commit comments