diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index 3cc398a6e9..aa26cd0c86 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -206,6 +206,24 @@ func.func @constant() -> tensor { // ----- +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: @dynamic_broadcast +// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[C1]]] : tensor<2xi32> +// CHECK: %[[CAST:.+]] = arith.index_cast %[[EXTRACT]] : i32 to index +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CAST]]) : tensor<1x?xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor) outs(%[[EMPTY]] : tensor<1x?xf32>) +func.func public @dynamic_broadcast(%arg0: tensor, %arg1 : tensor<2xi32>) -> (tensor<1x?xf32>) { + %c = stablehlo.constant dense<1> : tensor<1xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [1] : (tensor, tensor<2xi32>) -> tensor<1x?xf32> + return %4 : tensor<1x?xf32> +} + +// ----- + // CHECK-LABEL: func @elided_constant // CHECK: %[[CONSTANT:.*]] = arith.constant dense_resource<__elided__> : tensor<1024xf32> func.func @elided_constant() -> tensor<1024xf32> { diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 45f3cbbb1c..d75bcad079 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -678,8 +678,7 @@ struct HloDynamicBroadcastInDimConverter final // Use static type info. auto bcastDims = op.getBroadcastDimensions(); for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) { - if (ShapedType::isDynamic(dim)) continue; - + // We can assume if the input is dynamic it is not expanding. bool isExpanding = dim == 1; dimExprs[idx] = isExpanding ? rewriter.getAffineConstantExpr(0) : rewriter.getAffineDimExpr(bcastDims[idx]);