diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 6521bea5f18a..2fdd4bc7c022 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -1529,6 +1529,11 @@ def PreCanonicalizationOptimizationPass : Pass<"pre-canonicalization-optimizatio
     "::mlir::vector::VectorDialect",
     "::mlir::tpu::TPUDialect",
   ];
+  let options = [
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"6", "">,
+    Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
+    Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
+  ];
 }
 
 #endif  // TPU_ATTRS
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index b4d5357bfc55..28e5dde37c7c 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -91,7 +91,9 @@ std::unique_ptr> createApplyVectorLayoutPass(
     const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});
 
 std::unique_ptr>
-createPreCanonicalizationOptimizationPass();
+createPreCanonicalizationOptimizationPass(
+    int hardware_generation = -1,
+    std::array target_shape = {8, 128});
 
 std::unique_ptr>
 createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 99308c3f9661..25e2205a0844 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -1424,66 +1424,6 @@ FailureOr canonicalize_vector_transpose(const CanonicalizeContext &ctx,
   return new_op;
 }
 
-// Finds the split point for a reshape between a multi-dimensional shape and a
-// shape where a suffix has been collapsed into a single dimension.
-//
-// This function checks if `src_shape` and `tgt_shape` follow the pattern:
-//   src_shape: (P..., S_1, S_2, ..., S_N)
-//   tgt_shape: (P..., T_collapsed)
-// where `P` is a common prefix and `product(S_1..S_N) == T_collapsed`.
-//
-// It handles a differing number of leading 1s in the prefix by stripping them
-// from both shapes before comparison.
-//
-// This utility is used for two inverse patterns:
-// 1. Collapse (e.g., `load` -> `reshape`): The function is called directly,
-//    where `src_shape` is the multi-dimensional pre-reshape vector shape.
-// 2. Expand (e.g., `reshape` -> `store`): The function is called with swapped
-//    arguments, where `src_shape` is the multi-dimensional *post-reshape*
-//    vector shape.
-//
-// Returns:
-//   - A pair containing:
-//     1. The index in `src_shape` where the collapsing suffix begins.
-//     2. The product of the collapsed dimensions excluding the innermost one
-//        (i.e., product(S_1..S_{N-1})), used as the "sublane product".
-//   - `std::nullopt` if the shapes do not match the pattern.
-std::optional> findSplitPoint(ArrayRef src_shape,
-                                ArrayRef tgt_shape) {
-  int s = 0, t = 0;
-  // drop leading 1s
-  while (s < src_shape.size() && src_shape[s] == 1) {
-    ++s;
-  }
-  while (t < tgt_shape.size() && tgt_shape[t] == 1) {
-    ++t;
-  }
-
-  // Find the end of the common prefix between the shapes (ignoring leading 1s).
-  int s_prefix_end = s, t_prefix_end = t;
-  while (s_prefix_end < src_shape.size() && t_prefix_end < tgt_shape.size() &&
-         src_shape[s_prefix_end] == tgt_shape[t_prefix_end]) {
-    ++s_prefix_end;
-    ++t_prefix_end;
-  }
-
-  // After the common prefix, the rest of the target shape must consist of just
-  // one dimension (the collapsed one).
-  if (t_prefix_end != tgt_shape.size() - 1) {
-    return std::nullopt;
-  }
-  int64_t src_prod = 1;
-  for (int i = s_prefix_end; i < src_shape.size(); ++i) {
-    src_prod *= src_shape[i];
-  }
-
-  if (tgt_shape.back() != src_prod) {
-    return std::nullopt;
-  }
-  src_prod /= src_shape.back();
-  return std::make_pair(s_prefix_end, src_prod);
-}
-
 FailureOr canonicalize_shape_cast(const CanonicalizeContext& ctx,
                                   Operation& raw_op) {
   CanonicalBuilder builder(ctx, raw_op.getLoc(), &raw_op);
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc b/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
index 1023f21ae0fb..5d2cd9c11b9e 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc
@@ -14,13 +14,19 @@ limitations under the License.
 ==============================================================================*/
 
 #include
+#include
+#include
 #include
 #include
 #include
 #include
+#include
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
@@ -29,9 +35,11 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
+#include "jaxlib/mosaic/dialect/tpu/util.h"
 
 namespace mlir::tpu {
 
@@ -41,6 +49,203 @@ namespace mlir::tpu {
 
 namespace {
 
+// Fuses a vector.shape_cast (that expands dimensions) into a subsequent
+// vector.store or dense tpu.vector_store: the un-expanded source vector is
+// written chunk by chunk through a reshaped i32 view of the destination
+// memref, and the original store and shape_cast are erased. This is the
+// inverse of the canonicalize_reshape func.
+//
+// `target_shape[1]` is the lane count the minormost dimension is checked
+// against. Packed element types (narrower than 32 bits) are only rewritten
+// when `hardware_generation` is at least 4.
+//
+// If any precondition checked below fails, the function returns without
+// modifying the IR.
+void CanonicalizeStore(int hardware_generation,
+                       std::array target_shape, Operation& raw_op) {
+  Value value_to_store;
+  TypedValue base;
+  ValueRange indices;
+
+  Operation* store_op;
+
+  if (auto store = dyn_cast(raw_op)) {
+    store_op = store.getOperation();
+    value_to_store = store.getValueToStore();
+    base = store.getBase();
+    indices = store.getIndices();
+  } else if (auto store = dyn_cast(raw_op)) {
+    store_op = store.getOperation();
+    value_to_store = store.getValueToStore();
+    base = store.getBase();
+    indices = store.getIndices();
+    if (!store.getStrides().empty() || store.getMask() || store.getAdd()) {
+      return;
+    }
+  } else {
+    return;
+  }
+
+  // Look for vector::ShapeCastOp feeding the store
+  auto shape_cast_op =
+      dyn_cast_if_present(value_to_store.getDefiningOp());
+  if (!shape_cast_op || !shape_cast_op.getResult().hasOneUse()) {
+    return;
+  }
+
+  auto src_ty = shape_cast_op.getSource().getType();
+  auto tgt_ty = shape_cast_op.getResult().getType();
+  auto memref_ty = base.getType();
+
+  if (tgt_ty.getShape() != memref_ty.getShape()) {
+    return;
+  }
+  if (!isContiguousMemref(base)) {
+    return;
+  }
+  if (src_ty.getRank() > tgt_ty.getRank()) {
+    return;
+  }
+  auto last_src_lanes = src_ty.getShape().back();
+  if (last_src_lanes % target_shape[1] != 0) {
+    return;
+  }
+  std::optional> split_opt =
+      findSplitPoint(tgt_ty.getShape(), src_ty.getShape());
+  if (!split_opt) {
+    return;
+  }
+  auto [split_point, sublane_prod] = *split_opt;
+
+  int64_t bitwidth = src_ty.getElementTypeBitWidth();
+  // Packing is only defined for element widths that evenly divide a 32-bit
+  // word. Wider elements (e.g. 64-bit) would make `packing` zero and the
+  // modulo and division by `packing` below would then be undefined.
+  if (bitwidth <= 0 || bitwidth > 32 || 32 % bitwidth != 0) {
+    return;
+  }
+  int64_t packing = 32 / bitwidth;
+  if (hardware_generation < 4 && packing > 1) {
+    return;
+  }
+  if (sublane_prod % packing != 0) {
+    return;
+  }
+
+  ImplicitLocOpBuilder b(store_op->getLoc(), store_op);
+  auto loc = store_op->getLoc();
+  auto i32_type = b.getI32Type();
+  int64_t num_i32_rows = sublane_prod / packing;
+
+  SmallVector mem_shape;
+  if (split_point == 0) {
+    mem_shape.push_back(sublane_prod);
+  } else {
+    mem_shape.assign(memref_ty.getShape().begin(),
+                     memref_ty.getShape().begin() + split_point);
+    int64_t prev_dim = mem_shape.back();
+    int64_t new_dim = prev_dim * sublane_prod;
+    if (sublane_prod != 0 && new_dim / sublane_prod != prev_dim) {
+      return;
+    }
+    mem_shape.back() = new_dim;
+  }
+
+  auto lane_dim = memref_ty.getShape().back();
+  if (lane_dim != target_shape[1]) {
+    return;
+  }
+  mem_shape.push_back(lane_dim);
+  Value reshaped_ref = b.create(
+      MemRefType::get(mem_shape, memref_ty.getElementType()), base);
+
+  // Reinterpret the memref as i32: `packing` rows of the narrow element type
+  // share one 32-bit row, so the second-minor dimension shrinks accordingly.
+  *(mem_shape.end() - 2) /= packing;
+  Value i32_view = b.create(
+      MemRefType::get(mem_shape, i32_type), reshaped_ref);
+
+  Value src_vec = shape_cast_op.getSource();
+  SmallVector slice_sizes(src_ty.getShape());
+  slice_sizes.back() = lane_dim;
+  SmallVector unit_strides(src_ty.getRank(), 1);
+
+  auto i32_view_shape = cast(i32_view.getType()).getShape();
+
+  SmallVector store_indices;
+  Value split_base_idx;
+  int64_t stride_dim;
+
+  if (split_point == 0) {
+    // No common prefix - create indices for entire i32_view shape
+    split_base_idx = IdxConst(0, b, loc);
+    for (size_t i = 0; i < i32_view_shape.size(); ++i) {
+      store_indices.push_back(IdxConst(0, b, loc));
+    }
+    stride_dim = 0;
+  } else {
+    // Common prefix exists - use it
+    store_indices.assign(indices.begin(), indices.begin() + split_point);
+    split_base_idx = store_indices.back();
+    // Add remaining indices to match i32_view rank
+    while (store_indices.size() < i32_view_shape.size()) {
+      store_indices.push_back(IdxConst(0, b, loc));
+    }
+    stride_dim = split_point - 1;
+  }
+  SmallVector strides(i32_view_shape.size(), 1);
+  strides[stride_dim] = num_i32_rows;
+  for (int64_t i = 0; i < num_i32_rows; ++i) {
+    SmallVector offsets(src_ty.getRank(), 0);
+    offsets.back() = i * packing * lane_dim;
+    Value slice = b.create(
+        src_vec, offsets, slice_sizes, unit_strides);
+
+    auto i_chunk_ty =
+        VectorType::get(cast(slice.getType()).getShape(),
+                        b.getIntegerType(bitwidth));
+    auto i32_chunk_ty =
+        VectorType::get(cast(slice.getType()).getShape(), i32_type);
+    Value packed_chunk;
+    if (packing > 1) {
+      Value acc = b.create(
+          i32_chunk_ty, b.create(i_chunk_ty, slice));
+      for (int64_t p = 1; p < packing; ++p) {
+        offsets.back() = (i * packing + p) * lane_dim;
+        slice = b.create(
+            src_vec, offsets, slice_sizes, unit_strides);
+        Value sj_i32 = b.create(
+            i32_chunk_ty, b.create(i_chunk_ty, slice));
+        Value sh = I32Const(p * bitwidth, i32_chunk_ty.getShape(), b, loc);
+        acc = b.create(acc, b.create(sj_i32, sh));
+      }
+      packed_chunk = acc;
+    } else {
+      packed_chunk = b.create(i32_chunk_ty, slice);
+    }
+
+    auto packed_shape = cast(packed_chunk.getType()).getShape();
+    Value chunk_to_store = packed_chunk;
+    if (i32_view_shape.size() > packed_shape.size()) {
+      SmallVector reshape_vec_shape(
+          i32_view_shape.size() - packed_shape.size(), 1);
+      reshape_vec_shape.append(packed_shape.begin(), packed_shape.end());
+      auto reshape_type = VectorType::get(reshape_vec_shape, i32_type);
+      chunk_to_store = b.create(reshape_type, packed_chunk);
+    }
+    store_indices[stride_dim] =
+        b.create(split_base_idx, IdxConst(i, b, loc));
+
+    b.create(chunk_to_store, i32_view, store_indices,
+             strides);
+  }
+
+  // The shape_cast was checked to have this store as its only use, so erase
+  // the store first and then the now-unused cast.
+  store_op->erase();
+  shape_cast_op->erase();
+}
+
 struct RhsTraversalResult {
   tpu::TransposeOp transpose_op = nullptr;
   vector::ExtractStridedSliceOp slice_op = nullptr;
@@ -180,7 +385,16 @@ tryFuseRhsTranspose(tpu::MatmulOp op, ImplicitLocOpBuilder& builder) {
 struct PreCanonicalizationOptimizationPass
     : impl::PreCanonicalizationOptimizationPassBase<
           PreCanonicalizationOptimizationPass> {
+  // NOTE(review): the pass's TableGen options (hardware-generation,
+  // lane-count, sublane-count) are not read anywhere in this struct; only
+  // these constructor arguments are used. Confirm whether that is intended.
+  PreCanonicalizationOptimizationPass(int hardware_generation_p,
+                                      std::array target_shape_p)
+      : hardware_generation_(hardware_generation_p),
+        target_shape_(target_shape_p) {}
+
   void runOnOperation() override {
+    // Rewrite matmuls first; the store rewrites follow below.
     getOperation().walk([&](tpu::MatmulOp op) {
       // We only attempt this fusion if dimension numbers are present.
       if (!op.getDimensionNumbers().has_value()) {
@@ -196,14 +410,33 @@ struct PreCanonicalizationOptimizationPass
         op.setDimensionNumbersAttr(new_dnums);
       }
     });
+
+    // Fold expanding shape_casts into the stores that consume them. Both
+    // vector.store and tpu.vector_store are handled by CanonicalizeStore.
+    getOperation().walk([&](vector::StoreOp op) {
+      CanonicalizeStore(hardware_generation_, target_shape_,
+                        *op.getOperation());
+    });
+
+    getOperation().walk([&](tpu::VectorStoreOp op) {
+      CanonicalizeStore(hardware_generation_, target_shape_,
+                        *op.getOperation());
+    });
   }
+
+ private:
+  // Configuration captured at construction and forwarded to CanonicalizeStore.
+  int hardware_generation_;
+  std::array target_shape_;
 };
 
 }  // namespace
 
 std::unique_ptr>
-createPreCanonicalizationOptimizationPass() {
-  return std::make_unique();
+createPreCanonicalizationOptimizationPass(int hardware_generation,
+                                          std::array target_shape) {
+  return std::make_unique(
+      hardware_generation, target_shape);
 }
 
-}  // namespace mlir::tpu
\ No newline at end of file
+}  // namespace mlir::tpu
diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc
index 4bcb8a6eeb85..b75e863face4 100644
--- a/jaxlib/mosaic/dialect/tpu/util.cc
+++ b/jaxlib/mosaic/dialect/tpu/util.cc
@@ -109,6 +109,42 @@ FailureOr> computeSqueezedDimsChecked(
   return squeezed;
 }
 
+std::optional> findSplitPoint(
+    ArrayRef src_shape, ArrayRef tgt_shape) {
+  int64_t s = 0, t = 0;
+  // Drop leading 1s.
+  while (s < src_shape.size() && src_shape[s] == 1) {
+    ++s;
+  }
+  while (t < tgt_shape.size() && tgt_shape[t] == 1) {
+    ++t;
+  }
+
+  // Find the end of the common prefix between the shapes (ignoring leading 1s).
+  int64_t s_prefix_end = s, t_prefix_end = t;
+  while (s_prefix_end < src_shape.size() && t_prefix_end < tgt_shape.size() &&
+         src_shape[s_prefix_end] == tgt_shape[t_prefix_end]) {
+    ++s_prefix_end;
+    ++t_prefix_end;
+  }
+
+  // After the common prefix, the rest of the target shape must consist of just
+  // one dimension (the collapsed one).
+  if (t_prefix_end != tgt_shape.size() - 1) {
+    return std::nullopt;
+  }
+  int64_t src_prod = 1;
+  for (int64_t i = s_prefix_end; i < src_shape.size(); ++i) {
+    src_prod *= src_shape[i];
+  }
+
+  if (tgt_shape.back() != src_prod) {
+    return std::nullopt;
+  }
+  src_prod /= src_shape.back();
+  return std::make_pair(s_prefix_end, src_prod);
+}
+
 std::optional> isTransposedMatmul(
     DotDimensionNumbersAttr dim_numbers) {
   auto lhs_contracting_dims = dim_numbers.getLhsContractingDims();
diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h
index 20dd83918f3d..18eee8cc4657 100644
--- a/jaxlib/mosaic/dialect/tpu/util.h
+++ b/jaxlib/mosaic/dialect/tpu/util.h
@@ -228,6 +228,33 @@ FailureOr> computeSqueezedDimsChecked(
     Operation *op, ArrayRef source_shape,
     ArrayRef target_shape);
 
+// Finds the split point for a reshape between a multi-dimensional shape and a
+// shape where a suffix has been collapsed into a single dimension.
+//
+// This function checks if `src_shape` and `tgt_shape` follow the pattern:
+//   src_shape: (P..., S_1, S_2, ..., S_N)
+//   tgt_shape: (P..., T_collapsed)
+// where `P` is a common prefix and `product(S_1..S_N) == T_collapsed`.
+//
+// It handles a differing number of leading 1s in the prefix by stripping them
+// from both shapes before comparison.
+//
+// This utility is used for two inverse patterns:
+// 1. Collapse (e.g., `load` -> `reshape`): The function is called directly,
+//    where `src_shape` is the multi-dimensional pre-reshape vector shape.
+// 2. Expand (e.g., `reshape` -> `store`): The function is called with swapped
+//    arguments, where `src_shape` is the multi-dimensional *post-reshape*
+//    vector shape.
+//
+// Returns:
+//   - A pair containing:
+//     1. The index in `src_shape` where the collapsing suffix begins.
+//     2. The product of the collapsed dimensions excluding the innermost one
+//        (i.e., product(S_1..S_{N-1})), used as the "sublane product".
+//   - `std::nullopt` if the shapes do not match the pattern.
+std::optional> findSplitPoint(
+    ArrayRef src_shape, ArrayRef tgt_shape);
+
 // Assuming MKN matmul - This function must only be called after
 // canonicalization passes.
 //