Commit 4227a87

1 parent 0649133

18 files changed (+298, -141 lines)

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

-LLVM_COMMIT = "113f01aa82d055410f22a9d03b3468fa68600589"
+LLVM_COMMIT = "3a6b818132e3133c7d33f8f577e62503f12869b4"

-LLVM_SHA256 = "9aee00a35aa76639746589c6d09e8c18249be16b5b6aa6b788a570a4bc6c4543"
+LLVM_SHA256 = "a0b3de698393e0f49d0aca3f869cc03bf0c59eba0c65f608e565278943c31958"

http_archive(
    name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-113f01aa82d055410f22a9d03b3468fa68600589
+3a6b818132e3133c7d33f8f577e62503f12869b4

docs/generated/stablehlo_linalg_passes.md

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@ _Legalize StableHLO to LinAlg_
#### Options

```
-  -enable-primitive-ops : Lower to primitive Linalg ops (map, reduce and transpose) when possible, instead of linalg.generic
-  -enable-sparse-ops    : Lower to Sparse Tensor ops (sparse_tensor.concatenate)when possible, instead of linalg.generic
+  -enable-primitive-ops  : Lower to primitive Linalg ops (map, reduce and transpose) when possible, instead of linalg.generic
+  -enable-sparse-ops     : Lower to Sparse Tensor ops (sparse_tensor.concatenate)when possible, instead of linalg.generic
+  -capture-scalar-inputs : Capture scalar inputs in generic ops instead ofpassing as tensor-scalar argument.
```

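The new `-capture-scalar-inputs` option controls whether a rank-0 operand (such as a select predicate) is threaded into the generated `linalg.generic` as a tensor operand or captured inside the region; the existing CHECK run keeps its expectations, so capturing appears to be the default. Below is a hand-written sketch of the `capture-scalar-inputs=false` shape, reconstructed from the CHECK-NO-CAPTURE expectations added to pointwise.mlir further down; the function name and SSA value names are illustrative, not compiler output:

```mlir
// Scalar predicate mapped with a rank-0 indexing map instead of being
// captured in the region (capture-scalar-inputs=false).
#scalar = affine_map<(d0, d1) -> ()>
#id = affine_map<(d0, d1) -> (d0, d1)>

func.func @select_no_capture(%pred: tensor<i1>, %lhs: tensor<2x?xf32>,
                             %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %lhs, %c1 : tensor<2x?xf32>
  %dst = tensor.empty(%dim) : tensor<2x?xf32>
  // The predicate tensor is a regular input of the generic op.
  %0 = linalg.generic
         {indexing_maps = [#scalar, #id, #id, #id],
          iterator_types = ["parallel", "parallel"]}
         ins(%pred, %lhs, %rhs : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
         outs(%dst : tensor<2x?xf32>) {
  ^bb0(%p: i1, %l: f32, %r: f32, %out: f32):
    %sel = arith.select %p, %l, %r : f32
    linalg.yield %sel : f32
  } -> tensor<2x?xf32>
  func.return %0 : tensor<2x?xf32>
}
```

With the default capture behavior, the predicate element is instead materialized once outside the generic (the existing CHECK-PRIMITIVE lines refer to a `%PRED_ELEM` value) and used directly in the region, so only `%lhs` and `%rhs` are mapped inputs.
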
docs/generated/stablehlo_optimization_passes.md

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@ _Folds StableHLO operations_

```
  -assume-no-undeclared-side-effects : Allow dead code to be eliminated in some situations (e.g. dead while loops) under the assumption that ops are pure unless declared with explicit MLIR `MemoryEffects`. Notably, this means `func.call` ops will be assumed pure.
-  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `iota` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded.
+  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `concat` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded. Splat folds are exempt from this limit.
  -optimize-float : Allow float optimizations that, though mathematically equivalent, may result in slightly different quantization of floating-point values (e.g. `log(sqrt(x))` -> `0.5 * log(x)`). Float optimizations that can't affect numerical results are always enabled.
```

@@ -105,7 +105,7 @@ high coverage of the pass today.
#### Options

```
-  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `iota` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded.
+  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `concat` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded. Splat folds are exempt from this limit.
```

### `-stablehlo-target-independent-optimization`
@@ -123,6 +123,6 @@ Users should prefer this pass to calling the others directly.

```
  -assume-no-undeclared-side-effects : Allow dead code to be eliminated in some situations (e.g. dead while loops) under the assumption that ops are pure unless declared with explicit MLIR `MemoryEffects`. Notably, this means `func.call` ops will be assumed pure.
-  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `iota` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded.
+  -fold-op-element-limit : Folding an op into a constant can sometimes come at the cost of memory overhead. (This occurs if the op's inputs are reused, meaning that they can't be deleted after the op is folded to a constant, or when folding operations like `concat` whose outputs take up more memory than their inputs.) In such cases, this config option sets an upper limit on how many elements an op's result may have before the op is no longer folded. Splat folds are exempt from this limit.
  -optimize-float : Allow float optimizations that, though mathematically equivalent, may result in slightly different quantization of floating-point values (e.g. `log(sqrt(x))` -> `0.5 * log(x)`). Float optimizations that can't affect numerical results are always enabled.
```

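To make the `-fold-op-element-limit` wording concrete, here is a small hand-written example, not part of the commit (the function name is made up): folding the `concatenate` below would materialize an 8-element constant from two 4-element inputs, so a limit smaller than 8 would leave the op unfolded, whereas a splat result would be folded regardless, per the updated description.

```mlir
func.func @fold_candidate() -> tensor<8xf32> {
  %lhs = stablehlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
  %rhs = stablehlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf32>
  // Folding this op creates an 8-element dense constant, i.e. a result larger
  // than either input; -fold-op-element-limit caps how large such a folded
  // result may be before folding is skipped.
  %0 = "stablehlo.concatenate"(%lhs, %rhs) {dimension = 0 : i64}
      : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
  func.return %0 : tensor<8xf32>
}
```
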
stablehlo/conversions/linalg/tests/pointwise.mlir

Lines changed: 31 additions & 0 deletions
@@ -1,5 +1,6 @@
// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg --split-input-file --canonicalize | FileCheck %s
// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg="enable-primitive-ops=true" --split-input-file --canonicalize | FileCheck %s --check-prefix=CHECK-PRIMITIVE
+// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg="capture-scalar-inputs=false" --split-input-file --canonicalize | FileCheck %s --check-prefix=CHECK-NO-CAPTURE

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_add

@@ -538,6 +539,19 @@ func.func @complex_sign(

// -----

+// CHECK-LABEL: func @float_tan
+// CHECK-PRIMITIVE-LABEL: func @float_tan
+func.func @float_tan(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // CHECK: linalg.generic
+  // CHECK: tan
+  // CHECK-PRIMITIVE: linalg.map
+  // CHECK-PRIMITIVE: tan
+  %0 = "stablehlo.tan"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @float_tanh
// CHECK-PRIMITIVE-LABEL: func @float_tanh
func.func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {

@@ -927,6 +941,23 @@ func.func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rh
// CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32
// CHECK-PRIMITIVE: linalg.yield %[[RES]]

+// CHECK-NO-CAPTURE: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-NO-CAPTURE: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-NO-CAPTURE: func @select_scalar_pred_dyn
+// CHECK-NO-CAPTURE-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
+// CHECK-NO-CAPTURE-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-NO-CAPTURE-DAG: %[[DIM:.*]] = tensor.dim %[[LHS]], %[[C1]]
+// CHECK-NO-CAPTURE-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]])
+// CHECK-NO-CAPTURE: linalg.generic
+// CHECK-NO-CAPTURE-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
+// CHECK-NO-CAPTURE-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-NO-CAPTURE-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
+// CHECK-NO-CAPTURE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+// CHECK-NO-CAPTURE-SAME: {someattr}
+// CHECK-NO-CAPTURE: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
+// CHECK-NO-CAPTURE: %[[RES:.*]] = arith.select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
+// CHECK-NO-CAPTURE: linalg.yield %[[RES]]
+
// -----

// CHECK: func @select_scalar_pred_static

stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp

Lines changed: 4 additions & 5 deletions
@@ -140,12 +140,11 @@ Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
  // (any sign-op, or an integral abs-op).
  // TODO(peiming, ajcbik): these all can potentially be optimized by applying
  // value transform on sparse_tenosr.value memref
-  if (isa<mlir::stablehlo::SignOp>(op) || isa<mlir::stablehlo::NegOp>(op) ||
+  if (isa<mlir::stablehlo::SignOp, mlir::stablehlo::NegOp,
+          mlir::stablehlo::TanOp>(op) ||
      (isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
-      isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) ||
-      isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) ||
-      isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) ||
-      isa<chlo::TanOp>(op)) {
+      isa<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp, chlo::AtanhOp,
+          chlo::BesselI1eOp, chlo::SinhOp, chlo::TanOp>(op)) {
    if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
        !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
      return Value();
