diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index b890f7fbc7b2..2efe19beb15f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -1470,15 +1470,6 @@ def TPU_LogBufferOp : TPU_Op<"log_buffer"> { -def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createLogicalToPhysicalDeviceIdPass(-1)"; - let options = [Option<"total_devices", "total-devices", "int", "", "">]; -} def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::FuncOp"> { let dependentDialects = [ @@ -1555,18 +1546,5 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO ]; } -def PreCanonicalizationOptimizationPass : Pass<"pre-canonicalization-optimization", "::mlir::func::FuncOp"> { - let summary = "Fold matmul rhs tranpose into the op before layout inference"; - let constructor = "::mlir::tpu::createPreCanonicalizationOptimizationPass()"; - let dependentDialects = [ - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let options = [ - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"6", "">, - Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, - Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, - ]; -} #endif // TPU_ATTRS diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 9ba415cef3f0..3efeada3c7bd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -90,14 +90,6 @@ std::unique_ptr> createRelayoutInsertionPass( std::unique_ptr> createApplyVectorLayoutPass( const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{}); -std::unique_ptr> -createPreCanonicalizationOptimizationPass( - int hardware_generation = -1, - std::array target_shape = {8, 128}); - -std::unique_ptr> -createLogicalToPhysicalDeviceIdPass(int64_t total_devices); - #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"