From e0b92bb5788953a8bf21868afa42df3d73709ca4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 7 Apr 2025 13:26:13 -0700 Subject: [PATCH 1/3] Drop failure case for `stablehlo.dynamic_broadcast_in_dim` The failure to broadcast dynamically makes the assumption the input dynamic shape could be expanded by being `1`. This should be handled by an earlier trasform to materialize a known broadcast if we intend to support both cases. --- .../linalg/tests/miscellaneous.mlir | 18 ++++++++++++++++++ .../transforms/StablehloLegalizeToLinalg.cpp | 3 +-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index 3cc398a6e9..03ada714e4 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -206,6 +206,24 @@ func.func @constant() -> tensor { // ----- +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @dynamic_broadcast +// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[C1]]] : tensor<2xi32> +// CHECK: %[[CAST:.+]] = arith.index_cast %[[EXTRACT]] : i32 to index +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CAST]]) : tensor<1x?xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor) outs(%[[EMPTY]] : tensor<1x?xf32>) +func.func public @dynamic_broadcast(%arg0: tensor, %arg1 : tensor<2xi32>) -> (tensor<1x?xf32>) { + %c = stablehlo.constant dense<1> : tensor<1xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [1] : (tensor, tensor<2xi32>) -> tensor<1x?xf32> + return %4 : tensor<1x?xf32> +} + +// ----- + // CHECK-LABEL: func @elided_constant // CHECK: %[[CONSTANT:.*]] = arith.constant dense_resource<__elided__> : tensor<1024xf32> func.func @elided_constant() -> tensor<1024xf32> { diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 45f3cbbb1c..d75bcad079 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -678,8 +678,7 @@ struct HloDynamicBroadcastInDimConverter final // Use static type info. auto bcastDims = op.getBroadcastDimensions(); for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) { - if (ShapedType::isDynamic(dim)) continue; - + // We can assume if the input is dynamic it is not expanding. bool isExpanding = dim == 1; dimExprs[idx] = isExpanding ? rewriter.getAffineConstantExpr(0) : rewriter.getAffineDimExpr(bcastDims[idx]); From 6bce1216185f1be4d9499a2c4a6e00699cd372e3 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 7 Apr 2025 14:07:18 -0700 Subject: [PATCH 2/3] fix test --- stablehlo/conversions/linalg/tests/miscellaneous.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index 03ada714e4..2f156293b6 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -206,8 +206,8 @@ func.func @constant() -> tensor { // ----- -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @dynamic_broadcast // CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] From fafb83dccd340efca188f4507e50ca6a437e1a41 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 7 Apr 2025 14:15:32 -0700 Subject: [PATCH 3/3] remove label to fix checks --- stablehlo/conversions/linalg/tests/miscellaneous.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index 2f156293b6..aa26cd0c86 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -209,7 +209,7 @@ func.func @constant() -> tensor { // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> // CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: @dynamic_broadcast +// CHECK: @dynamic_broadcast // CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[C1]]] : tensor<2xi32>