From 1b7e9c028999167e063ec33bb7946e2da97c7eb4 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Wed, 23 Dec 2020 08:04:58 -0800 Subject: [PATCH] Upstream mhlo.dot lowering to Linalg to MHLO repo. We prototyped the lowering from mhlo.dot to linalg.matmul in IREE. Since Linalg now supports matmul in tensors world, we can move the lowering logic to tensors world, and upstream to legalize_to_linalg.cc. The patch lowers the mhlo.dot to the linalg.matmul/matvec/dot in tensors world. PiperOrigin-RevId: 348796369 --- test/from_linalg_invalid.mlir | 2 +- transforms/sair_from_linalg.cc | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/from_linalg_invalid.mlir b/test/from_linalg_invalid.mlir index 0c18b29b..bfb010b2 100644 --- a/test/from_linalg_invalid.mlir +++ b/test/from_linalg_invalid.mlir @@ -11,7 +11,7 @@ } func @reductions(%arg0: memref<2x3x4x5x6xf32>, %arg1: memref<2x4x6xf32>) { - // expected-error @+1 {{Linalg op is not compatible with Sair}} + // expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd3' is function of reduction iterator 'd3'}} linalg.generic #reductions_trait ins(%arg0 : memref<2x3x4x5x6xf32>) outs(%arg1 : memref<2x4x6xf32>) { diff --git a/transforms/sair_from_linalg.cc b/transforms/sair_from_linalg.cc index be99e55c..a14eafad 100644 --- a/transforms/sair_from_linalg.cc +++ b/transforms/sair_from_linalg.cc @@ -431,7 +431,7 @@ void MoveBodyBlock(mlir::AffineMap linalg_to_sair_loops, // as "types" and a hyper-rectangular domain with the given number of dimensons. // Uses "rewriter" to construct the types. void CreateResultTypes(mlir::Builder &rewriter, int num_dimensions, - mlir::TypeRange types, + const SmallVectorImpl &types, llvm::SmallVectorImpl &result_types) { mlir::MLIRContext *context = rewriter.getContext(); auto result_domain_shape = @@ -562,7 +562,7 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op, // Linalg does not seem to restrict the output indexing to parallel dimensions // only, but Sair does. Abort the conversion in case of incompatibility. int num_parallel_loops = op.getNumParallelLoops(); - int num_operands = op.getNumInputsAndOutputBuffers(); + int num_operands = op.getNumShapedOperands(); for (int i = op.getNumInputs(); i < num_operands; ++i) { auto mapping = operand_mappings[i].cast(); if (mlir::failed(VerifyReductionMapping(mapping, num_parallel_loops))) { @@ -589,20 +589,20 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op, // Convert input and input/output MemRefs used by Linalg to Sair values. llvm::SmallVector map_operands; llvm::SmallVector, 4> result_ranges; - EmitMemRefToValue(op.getInputsAndOutputBuffers(), op.getNumOutputs(), loc, + EmitMemRefToValue(op.getShapedOperands(), op.getNumOutputs(), loc, sair_program, rewriter, map_operands, result_ranges); // Prepare parameters of the Sair map operation. int num_loops = op.getNumLoops(); llvm::SmallVector loop_bounds; - CollectLoopBounds(num_loops, subscripts_to_loops, - op.getInputsAndOutputBuffers(), loop_bounds); + CollectLoopBounds(num_loops, subscripts_to_loops, op.getShapedOperands(), + loop_bounds); llvm::SmallVector domain_ranges = CreateSairDomain(loc, loop_bounds, sair_program, rewriter); llvm::SmallVector result_types; - CreateResultTypes(rewriter, num_parallel_loops, - op.getOutputBuffers().getTypes(), result_types); + CreateResultTypes(rewriter, num_parallel_loops, op.getOutputBufferTypes(), + result_types); // Construct the main map or map_reduce operation. mlir::Operation *map_op;