Skip to content

Commit 53daa0c

Browse files
bythew3i and jax authors
authored and committed
[XLA:Mosaic] Fix infer layout for nested loop.
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.

PiperOrigin-RevId: 641973011
1 parent f6ce973 commit 53daa0c

File tree

1 file changed

+95
-105
lines changed

1 file changed

+95
-105
lines changed

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 95 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ limitations under the License.
4848
#include "mlir/include/mlir/IR/Attributes.h"
4949
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
5050
#include "mlir/include/mlir/IR/OpDefinition.h"
51+
#include "mlir/include/mlir/IR/Visitors.h"
5152
#include "jaxlib/mosaic/dialect/tpu/layout.h"
5253
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
5354
#include "xla/layout.h"
@@ -122,10 +123,7 @@ class VectorLayoutInferer {
122123

123124
LogicalResult inferBlock(
124125
Block &block,
125-
const std::function<LogicalResult(Operation *)> &match_terminator,
126-
// TODO(jevinjiang): Propagate this flag deeper because it won't work when
127-
// there is an op with blocks inside this block.
128-
bool override_layout = false) {
126+
const std::function<LogicalResult(Operation *)> &match_terminator) {
129127
for (Operation &any_op : block.without_terminator()) {
130128
VLOG(kLayoutLog) << Print(&any_op);
131129
if (any_op.hasAttr("in_layout") || any_op.hasAttr("out_layout")) {
@@ -134,8 +132,6 @@ class VectorLayoutInferer {
134132
any_op.hasAttr("in_layout") && any_op.hasAttr("out_layout"),
135133
"expect layout attributes in tpu::AssumeLayoutOp");
136134
continue;
137-
} else if (override_layout) {
138-
// Intend to override the layouts attribute.
139135
} else {
140136
any_op.emitOpError("layout attributes already attached");
141137
return failure();
@@ -508,35 +504,15 @@ class VectorLayoutInferer {
508504
op->getNumOperands() == 3 + op.getNumResults(),
509505
"expected num_operands is equal to 3 + num_results in scf.for");
510506

511-
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
512-
// Drop the first 3 layouts for lower bound, upper bound and step.
513-
ArrayRef<Layout> arg_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
514-
SmallVector<tpu::AssumeLayoutOp, 4> assume_layout_ops;
515-
assume_layout_ops.reserve(arg_layouts.size());
516-
// Use tpu.assume_layout to annotate every block argument with the layout of
517-
// the corresponding operand in forOp and replace all uses of the block
518-
// argument with the result of tpu.assume_layout.
519-
ImplicitLocOpBuilder builder =
520-
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody());
521-
522-
// Drop the induction_variable and layouts of bounds+step (respectively).
523-
for (auto [iter_arg, layout] : llvm::zip_equal(
524-
op.getBody()->getArguments().drop_front(1), arg_layouts)) {
525-
if (!dyn_cast<VectorType>(iter_arg.getType())) {
526-
assume_layout_ops.push_back(nullptr);
527-
continue;
528-
}
529-
auto assume_layout_op =
530-
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
531-
setLayout(assume_layout_op, layout, layout);
532-
assume_layout_ops.push_back(assume_layout_op);
533-
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
534-
return operand.getOwner() != assume_layout_op;
535-
});
536-
}
537-
538-
if (inferBlock(*op.getBody(), match_yield).failed()) {
539-
return failure();
507+
auto in_layouts = getLayoutFromOperands(op);
508+
// Drop the input layouts for lower bound, upper bound. But keep the layout
509+
// for step because it matches with induction variable in arguments.
510+
auto arg_layouts = ArrayRef<Layout>(in_layouts).drop_front(2);
511+
if (assumeLayoutsForBlockArgs(*op.getBody(), arg_layouts).failed() ||
512+
inferBlock(*op.getBody(), match_yield).failed()) {
513+
return op.emitOpError(
514+
"failed to infer layout with initial layouts for body in "
515+
"scf.for op");
540516
}
541517
auto yield_op = op.getBody()->getTerminator();
542518
auto yield_in_layouts = getLayoutFromOperands(yield_op);
@@ -546,7 +522,8 @@ class VectorLayoutInferer {
546522
int out_idx = 0;
547523
bool require_reinfer = false;
548524
for (auto [in_layout, yield_layout, result] :
549-
llvm::zip_equal(ArrayRef<Layout>(in_layouts).drop_front(3),
525+
llvm::zip_equal(arg_layouts.drop_front(
526+
1), // Drop the layout for induction variable.
550527
yield_in_layouts, op.getResults())) {
551528
if (auto vty = dyn_cast<VectorType>(result.getType())) {
552529
if (!in_layout.has_value()) {
@@ -586,24 +563,25 @@ class VectorLayoutInferer {
586563
++out_idx;
587564
}
588565
if (require_reinfer) {
566+
// Force same layouts in input layout but skip the first 3 layouts for
567+
// lower bound, upper bound and step.
568+
std::copy(out_layouts.begin(), out_layouts.end(), in_layouts.begin() + 3);
569+
589570
// Terminator in the loop will carry layouts to the next loop but
590571
// the loop's block args' layouts are determined by the initial inputs. We
591572
// need to force the same layouts for all in order to make layouts be
592573
// consistent across all branches. To ensure that, we need to reprocess
593574
// layout inference for the entire body with the final consolidated
594575
// layout.
595-
for (int64_t i = 0; i < out_layouts.size(); ++i) {
596-
if (assume_layout_ops[i]) {
597-
setLayout(assume_layout_ops[i], out_layouts[i], out_layouts[i]);
598-
}
599-
}
600-
if (inferBlock(*op.getBody(), match_yield, /*override_layout=*/true)
601-
.failed()) {
602-
return op.emitOpError("failed to infer layout for scf.for op");
576+
clearBlockLayouts(*op.getBody());
577+
if (assumeLayoutsForBlockArgs(*op.getBody(),
578+
ArrayRef<Layout>(in_layouts).drop_front(2))
579+
.failed() ||
580+
inferBlock(*op.getBody(), match_yield).failed()) {
581+
return op.emitOpError(
582+
"failed to infer layout with compatible layouts for body in "
583+
"scf.for op");
603584
}
604-
std::copy(out_layouts.begin(), out_layouts.end(),
605-
in_layouts.begin() + 3); // Skip first 3 layouts for lower
606-
// bound, upper bound and step.
607585
}
608586
setInLayout(yield_op, out_layouts);
609587
setLayout(op, in_layouts, out_layouts);
@@ -622,53 +600,19 @@ class VectorLayoutInferer {
622600
TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while");
623601

624602
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
625-
SmallVector<tpu::AssumeLayoutOp, 4> before_assume_layout_ops;
626-
before_assume_layout_ops.reserve(in_layouts.size());
627-
SmallVector<tpu::AssumeLayoutOp, 4> after_assume_layout_ops;
628-
after_assume_layout_ops.reserve(in_layouts.size());
629603

630-
// Use tpu.assume_layout to annotate every block argument with the layout of
631-
// the corresponding operand in WhileOp and replace all uses of the block
632-
// argument with the result of tpu.assume_layout.
633-
ImplicitLocOpBuilder builder =
634-
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBeforeBody());
635-
for (auto [iter_arg, layout] :
636-
llvm::zip_equal(op.getBeforeBody()->getArguments(), in_layouts)) {
637-
if (!dyn_cast<VectorType>(iter_arg.getType())) {
638-
before_assume_layout_ops.push_back(nullptr);
639-
continue;
640-
}
641-
auto assume_layout_op =
642-
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
643-
setLayout(assume_layout_op, layout, layout);
644-
before_assume_layout_ops.push_back(assume_layout_op);
645-
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
646-
return operand.getOwner() != assume_layout_op;
647-
});
648-
}
649-
if (inferBlock(*op.getBeforeBody(), match_condition).failed()) {
650-
return failure();
604+
if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), in_layouts).failed() ||
605+
inferBlock(*op.getBeforeBody(), match_condition).failed()) {
606+
return op.emitOpError(
607+
"failed to infer layout with initial layouts for before body in "
608+
"scf.while op");
651609
}
652610

653-
builder =
654-
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getAfterBody());
655-
for (auto [iter_arg, layout] :
656-
llvm::zip_equal(op.getAfterBody()->getArguments(), in_layouts)) {
657-
if (!dyn_cast<VectorType>(iter_arg.getType())) {
658-
after_assume_layout_ops.push_back(nullptr);
659-
continue;
660-
}
661-
auto assume_layout_op =
662-
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
663-
setLayout(assume_layout_op, layout, layout);
664-
after_assume_layout_ops.push_back(assume_layout_op);
665-
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
666-
return operand.getOwner() != assume_layout_op;
667-
});
668-
}
669-
670-
if (inferBlock(*op.getAfterBody(), match_yield).failed()) {
671-
return failure();
611+
if (assumeLayoutsForBlockArgs(*op.getAfterBody(), in_layouts).failed() ||
612+
inferBlock(*op.getAfterBody(), match_yield).failed()) {
613+
return op.emitOpError(
614+
"failed to infer layout with initial layouts for after body in "
615+
"scf.while op");
672616
}
673617

674618
auto *cond_op = op.getBeforeBody()->getTerminator();
@@ -738,27 +682,26 @@ class VectorLayoutInferer {
738682
++out_idx;
739683
}
740684
if (require_reinfer) {
685+
clearBlockLayouts(*op.getBeforeBody());
686+
clearBlockLayouts(*op.getAfterBody());
741687
// Terminator in the loop will carry layouts to the next loop but
742688
// the loop's block args' layouts are determined by the initial inputs. We
743689
// need to force the same layouts for all in order to make layouts be
744690
// consistent across all branches. To ensure that, we need to reprocess
745691
// layout inference for the entire body with the final consolidated
746692
// layout.
747-
for (int64_t i = 0; i < out_layouts.size(); ++i) {
748-
if (before_assume_layout_ops[i]) {
749-
setLayout(before_assume_layout_ops[i], out_layouts[i],
750-
out_layouts[i]);
751-
}
752-
if (after_assume_layout_ops[i]) {
753-
setLayout(after_assume_layout_ops[i], out_layouts[i], out_layouts[i]);
754-
}
755-
}
756-
if (inferBlock(*op.getBeforeBody(), match_condition,
757-
/*override_layout=*/true)
693+
if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), out_layouts)
758694
.failed() ||
759-
inferBlock(*op.getAfterBody(), match_yield, /*override_layout=*/true)
760-
.failed()) {
761-
return op.emitOpError("failed to infer layout for scf.while op");
695+
inferBlock(*op.getBeforeBody(), match_condition).failed()) {
696+
return op.emitOpError(
697+
"failed to infer layout with compatible layouts for before body in "
698+
"scf.while op");
699+
}
700+
if (assumeLayoutsForBlockArgs(*op.getAfterBody(), out_layouts).failed() ||
701+
inferBlock(*op.getAfterBody(), match_yield).failed()) {
702+
return op.emitOpError(
703+
"failed to infer layout with compatible layouts for after body in "
704+
"scf.while op");
762705
}
763706
}
764707
std::copy(out_layouts.begin(), out_layouts.end(),
@@ -1854,6 +1797,53 @@ class VectorLayoutInferer {
18541797
return true;
18551798
}
18561799

1800+
// Inserts a tpu.assume_layout op at the beginning of `block` for every
// vector-typed block argument, annotated with the matching entry of `layouts`
// as both its input and output layout, and redirects the argument's other
// uses to that op's result.
//
// `layouts` must contain exactly one entry per block argument (entries for
// non-vector arguments are ignored).
//
// Returns failure, after emitting an error on the block's parent op, when the
// number of layouts differs from the number of block arguments or when a
// vector argument is already used by a tpu.assume_layout op.
LogicalResult assumeLayoutsForBlockArgs(Block &block,
1801+
ArrayRef<Layout> layouts) {
1802+
auto op = block.getParentOp();
1803+
if (layouts.size() != block.getNumArguments()) {
1804+
return op->emitOpError(
1805+
"Block arguments must have the same number of layouts");
1806+
}
1807+
// Use tpu.assume_layout to annotate every block argument with the layout of
1808+
// the corresponding operand and replace all uses of the block argument with
1809+
// the result of tpu.assume_layout.
1810+
ImplicitLocOpBuilder builder =
1811+
ImplicitLocOpBuilder::atBlockBegin(op->getLoc(), &block);
1812+
for (auto [iter_arg, layout] :
1813+
llvm::zip_equal(block.getArguments(), layouts)) {
1814+
// Non-vector arguments are left untouched.
if (!dyn_cast<VectorType>(iter_arg.getType())) {
1815+
continue;
1816+
}
1817+
if (llvm::any_of(iter_arg.getUsers(), [](Operation *user) {
1818+
return isa<tpu::AssumeLayoutOp>(user);
1819+
})) {
1820+
return op->emitOpError("Expected no assume layout for block arguments");
1821+
}
1822+
auto assume_layout_op =
1823+
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
1824+
setLayout(assume_layout_op, layout, layout);
1825+
// Redirect every use except the new op's own operand, which must keep
// reading the block argument directly.
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
1826+
return operand.getOwner() != assume_layout_op;
1827+
});
1828+
}
1829+
return success();
1830+
}
1831+
1832+
// Walks the operations in `block` and strips layout information from them:
// each tpu.assume_layout op is erased after all uses of its result are
// replaced with its input, and the "in_layout" and "out_layout" attributes
// are removed from every other visited op.
void clearBlockLayouts(Block &block) {
1833+
block.walk([&](Operation *op) {
1834+
// We need to remove assume_layout ops in each block. Otherwise, we will
1835+
// create extra assume_layout ops for nested blocks.
1836+
if (auto assume_op = dyn_cast<tpu::AssumeLayoutOp>(op)) {
1837+
assume_op.getResult().replaceAllUsesWith(assume_op.getInput());
1838+
assume_op->erase();
1839+
return WalkResult::advance();
1840+
}
1841+
op->removeAttr("in_layout");
1842+
op->removeAttr("out_layout");
1843+
return WalkResult::advance();
1844+
});
1845+
}
1846+
18571847
void setInLayout(Operation *op, ArrayRef<Layout> in) {
18581848
CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
18591849
SmallVector<Attribute, 4> in_attrs;

0 commit comments

Comments
 (0)