diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir
index 4f1fd660a6..10d894086a 100644
--- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir
+++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir
@@ -1,5 +1,6 @@
 // RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg --split-input-file --canonicalize | FileCheck %s
 // RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg="enable-primitive-ops=true" --split-input-file --canonicalize | FileCheck %s --check-prefix=CHECK-PRIMITIVE
+// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg="enable-primitive-ops=true" --split-input-file | FileCheck %s --check-prefix=CHECK-PRIMITIVE-2

 // CHECK-LABEL: func @bitcast_convert
 func.func @bitcast_convert(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
@@ -1672,3 +1673,15 @@ func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> {
 // Regression test. Just check that unsigned ints lower successfully.
 // CHECK-LABEL: func @transpose_unsigned
 // CHECK-PRIMITIVE-LABEL: func @transpose_unsigned
+
+// -----
+
+func.func @dynamic_broadcast(%arg0: tensor<?x?x?xf32>, %arg1: tensor<3xindex>) -> (tensor<?x?x?xf32>) {
+  %0 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-PRIMITIVE-2: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<?x?x?xf32>) outs(%[[ARG1:.*]] : tensor<?x?x?xf32>) {
+// CHECK-PRIMITIVE-2: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-PRIMITIVE-2:   linalg.yield %[[IN]] : f32
+// CHECK-PRIMITIVE-2: } -> tensor<?x?x?xf32>
diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
index eb9b5198ef..f427abc775 100644
--- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
@@ -32,6 +32,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -845,6 +846,61 @@ struct DynamicBroadcastInDimOpToBroadcastConverter final
   }
 };

+// If the input and output of a dynamic_broadcast_in_dim have the same rank,
+// the same (dynamic) shape, and an identity broadcast_dimensions attribute
+// (i.e. no broadcasting is needed), then the op is effectively a copy from
+// input to output. In such cases, we can lower it to a simple linalg.generic
+// operation, which can later be canonicalised and bufferised accordingly.
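+//
+// As an illustrative sketch only (operand names here are made up and the
+// exact IR depends on later canonicalisation), the pattern rewrites
+//
+//   %0 = stablehlo.dynamic_broadcast_in_dim %arg0, %shape, dims = [0, 1, 2]
+//       : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+//
+// into an identity linalg.generic over a tensor.empty of the same dynamic
+// sizes, where %c0..%c2 are index constants and #map is the identity map
+// (d0, d1, d2) -> (d0, d1, d2):
+//
+//   %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+//   %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+//   %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+//   %empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+//   %0 = linalg.generic
+//       {indexing_maps = [#map, #map],
+//        iterator_types = ["parallel", "parallel", "parallel"]}
+//       ins(%arg0 : tensor<?x?x?xf32>) outs(%empty : tensor<?x?x?xf32>) {
+//   ^bb0(%in: f32, %out: f32):
+//     linalg.yield %in : f32
+//   } -> tensor<?x?x?xf32>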
+struct DynamicBroadcastInDimOpDynamicShapeConverter final
+    : OpConversionPattern<mlir::stablehlo::DynamicBroadcastInDimOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::stablehlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.getOperand();
+    Value output = op.getResult();
+    auto inputType = dyn_cast<ShapedType>(input.getType());
+    auto outputType = dyn_cast<ShapedType>(output.getType());
+    assert(inputType && outputType && "expected shaped type");
+    // The pattern only applies when the op is a plain copy: input and output
+    // agree in rank and shape, and the broadcast dimensions are the identity.
+    if (inputType.getRank() != outputType.getRank() ||
+        inputType.getShape() != outputType.getShape() ||
+        static_cast<int64_t>(op.getBroadcastDimensions().size()) !=
+            inputType.getRank()) {
+      return failure();
+    }
+    // Statically shaped copies are handled by the other broadcast converters.
+    if (inputType.hasStaticShape()) {
+      return failure();
+    }
+    if (!llvm::all_of(llvm::seq<int64_t>(0, inputType.getRank()),
+                      [&](int64_t i) {
+                        return i == op.getBroadcastDimensions()[i];
+                      })) {
+      return failure();
+    }
+    auto mixedSizes = tensor::getMixedSizes(rewriter, loc, input);
+    auto resultType = cast<RankedTensorType>(output.getType());
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, mixedSizes, resultType.getElementType());
+    auto map = AffineMap::getMultiDimIdentityMap(resultType.getRank(),
+                                                 rewriter.getContext());
+
+    SmallVector<AffineMap> indexingMaps{map, map};
+    SmallVector<Value> inputs{input};
+    SmallVector<Value> outputs{emptyTensor};
+    auto generic =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, TypeRange{resultType}, inputs, outputs, indexingMaps,
+                SmallVector<utils::IteratorType>(
+                    resultType.getRank(), utils::IteratorType::parallel),
+                [&](OpBuilder& builder, Location loc, ValueRange blockArgs) {
+                  builder.create<linalg::YieldOp>(loc, blockArgs[0]);
+                })
+            ->getResults();
+    rewriter.replaceOp(op, ValueRange{generic});
+    return success();
+  }
+};
+
 template <typename OpTy>
 struct TransposeConverter final
     : DataMovementOpConverter<TransposeConverter<OpTy>, OpTy> {
@@ -2694,6 +2750,7 @@ void populateStablehloToLinalgConversionPatterns(MLIRContext* context,
         BroadcastInDimOpToBroadcastConverter,
         BroadcastOpToBroadcastConverter,
         DynamicBroadcastInDimOpToBroadcastConverter,
+        DynamicBroadcastInDimOpDynamicShapeConverter,
         IotaToMapConverter<mlir::stablehlo::IotaOp>,
         IotaToMapConverter<mlir::stablehlo::DynamicIotaOp>,
         MapOpToMapConverter,