Skip to content

Commit e89f909

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic] infer-memref-layout is now defined via pass_boilerplate.h
There is little benefit in having pass definitions in the .td file. PiperOrigin-RevId: 834373569
1 parent 00d707e commit e89f909

File tree

2 files changed

+2
-19
lines changed

2 files changed

+2
-19
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,22 +1489,6 @@ def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mli
14891489
let options = [Option<"total_devices", "total-devices", "int", "", "">];
14901490
}
14911491

1492-
def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncOp"> {
1493-
let dependentDialects = [
1494-
"::mlir::func::FuncDialect",
1495-
"::mlir::memref::MemRefDialect",
1496-
];
1497-
let constructor = "::mlir::tpu::createInferMemRefLayoutPass()";
1498-
let options = [
1499-
// If hardware_generation is not set, the default value of -1 will crash on
1500-
// runOnOperation.
1501-
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
1502-
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
1503-
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
1504-
Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
1505-
];
1506-
}
1507-
15081492
def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::FuncOp"> {
15091493
let dependentDialects = [
15101494
"::mlir::arith::ArithDialect",

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@ struct ApplyVectorLayoutContext {
7070
std::pair<bool, bool> mightCommunicateBetweenChips(Operation *op);
7171

7272
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
73-
int hardware_generation = -1,
74-
std::array<int64_t, 2> target_shape = {8, 128},
75-
const TpuTilingFlags &tpu_tiling_flags = {});
73+
int hardware_generation, std::array<int64_t, 2> target_shape,
74+
const TpuTilingFlags& tpu_tiling_flags);
7675

7776
std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
7877
int hardware_generation = -1, bool compatibility_mode = true,

0 commit comments

Comments
 (0)