Skip to content

Commit 4bf77d2

Browse files
authored
1 parent a3c7de9 commit 4bf77d2

15 files changed

+377
-42
lines changed

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "ac9049df7e62e2ca4dc5d103593b51639b5715e3"
20+
LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260"
2121

22-
LLVM_SHA256 = "ea890ee3c13d9b2d70a359299a0b810c8bae9c729c5a94d81f5b304bf26f34b6"
22+
LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
ac9049df7e62e2ca4dc5d103593b51639b5715e3
1+
799e9053641a6478d3144866a97737b37b87c260

docs/generated/chlo.md

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,14 +1968,6 @@ Syntax:
19681968
>
19691969
```
19701970

1971-
Enum cases:
1972-
* EQ (`EQ`)
1973-
* NE (`NE`)
1974-
* GE (`GE`)
1975-
* GT (`GT`)
1976-
* LE (`LE`)
1977-
* LT (`LT`)
1978-
19791971
#### Parameters:
19801972

19811973
| Parameter | C++ type | Description |
@@ -1994,13 +1986,6 @@ Syntax:
19941986
>
19951987
```
19961988

1997-
Enum cases:
1998-
* NOTYPE (`NOTYPE`)
1999-
* FLOAT (`FLOAT`)
2000-
* TOTALORDER (`TOTALORDER`)
2001-
* SIGNED (`SIGNED`)
2002-
* UNSIGNED (`UNSIGNED`)
2003-
20041989
#### Parameters:
20051990

20061991
| Parameter | C++ type | Description |
@@ -2019,11 +2004,6 @@ Syntax:
20192004
>
20202005
```
20212006

2022-
Enum cases:
2023-
* DEFAULT (`DEFAULT`)
2024-
* HIGH (`HIGH`)
2025-
* HIGHEST (`HIGHEST`)
2026-
20272007
#### Parameters:
20282008

20292009
| Parameter | C++ type | Description |

docs/generated/stablehlo_passes.md

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,26 @@ func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tenso
259259

260260
_Legalize StableHLO to VHLO._
261261

262+
Legalize StableHLO to the latest version of ops in VHLO. These ops can then
263+
be downgraded to older versions of VHLO for forward compatibility using
264+
`VhloToVersionPass`.
265+
266+
```mlir
267+
stablehlo.exponential %[[ARG0]] <{result_accuracy = DEFAULT}> : tensor<f32>
268+
# ====>
269+
"vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.DEFAULT_v1}> : !vhlo.tensor_v1<!vhlo.f32_v1>
270+
```
271+
272+
See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md)
273+
for full details on how VHLO is used to preserve forward and backward
274+
compatibility.
275+
276+
#### Options
277+
278+
```
279+
-allow-other-dialects : Allow serialization to use other (potentially unstable) dialects, inserts unrealized casts between dialects.
280+
```
281+
262282
### `-stablehlo-refine-arguments`
263283

264284
_Refines the argument shapes of the main function._
@@ -279,6 +299,7 @@ func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
279299
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...}
280300
: (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
281301
...
302+
}
282303
```
283304

284305
The `refinedTypesOption` can be used to specify a list of refined types.
@@ -459,7 +480,22 @@ _Legalize VHLO to StableHLO._
459480

460481
### `-vhlo-to-version`
461482

462-
_Convert between versions of VHLO._
483+
_Convert between versions of VHLO for compatibility._
484+
485+
Converts between versions of VHLO for IR upgrading and downgrading to
486+
preserve forward and backward compatibility.
487+
488+
```mlir
489+
"vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}>
490+
# ==( -target=1.0.0 )==>
491+
"vhlo.exponential_v1"(%[[ARG0]])
492+
# ==( -target=1.9.0 )==>
493+
"vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = DEFAULT}>
494+
```
495+
496+
See [vhlo.md > The VHLO dialect](https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md)
497+
for full details on how VHLO is used to preserve forward and backward
498+
compatibility.
463499

464500
#### Options
465501

stablehlo/dialect/Serialization.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,25 @@ limitations under the License.
3232
#include "stablehlo/dialect/VhloOps.h"
3333
#include "stablehlo/transforms/Passes.h"
3434

35-
#define DEBUG_TYPE "compat-passes"
35+
#define DEBUG_TYPE "stablehlo-compat"
3636

3737
namespace mlir {
3838
namespace stablehlo {
3939

4040
LogicalResult serializePortableArtifact(ModuleOp module,
4141
StringRef targetVersion,
42-
raw_ostream& os) {
42+
raw_ostream& os,
43+
bool allowOtherDialects) {
4344
MLIRContext* context = module.getContext();
4445

45-
// Convert StableHLO --> VHLO. Will fail if entire program is not StableHLO.
46+
// Convert StableHLO --> VHLO.
47+
// If allowOtherDialects is true, we will allow other dialects to be present
48+
// in the module, otherwise will fail if there are any other dialects present.
4649
{
4750
PassManager pm(context);
48-
pm.addPass(stablehlo::createStablehloLegalizeToVhloPass());
51+
StablehloLegalizeToVhloPassOptions options;
52+
options.allowOtherDialects = allowOtherDialects;
53+
pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options));
4954
if (!succeeded(pm.run(module))) {
5055
return failure();
5156
}

stablehlo/dialect/Serialization.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ namespace stablehlo {
3434
// unsupported dialects.
3535
LogicalResult serializePortableArtifact(ModuleOp module,
3636
StringRef targetVersion,
37-
raw_ostream& os);
37+
raw_ostream& os,
38+
bool allowOtherDialects = false);
3839

3940
// Read StableHLO portable artifact
4041
//

stablehlo/dialect/VhloTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,20 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
322322
});
323323
}
324324

325+
namespace {
326+
Value materializeIllegalCast(OpBuilder& builder, Type type, ValueRange inputs,
327+
Location loc) {
328+
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
329+
->getResult(0);
330+
}
331+
} // namespace
332+
333+
void VhloTypeConverter::addUnrealizedMaterializations() {
334+
addTargetMaterialization(materializeIllegalCast);
335+
addSourceMaterialization(materializeIllegalCast);
336+
addArgumentMaterialization(materializeIllegalCast);
337+
}
338+
325339
namespace {
326340
// Helper functions for VHLO verifiers
327341
template <typename TypeOrAttr>

stablehlo/dialect/VhloTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class VhloTypeConverter : public VhloTypeConverterBase {
5555
// it is likely that a subclass should call this last, especially if a default
5656
// `Type -> Type` fallback conversion is registered.
5757
void addBuiltinToVhloConversions();
58+
59+
// Register unrealized_conversion_cast materializations for source, target,
// and argument conversions. Useful for dialect mixing.
60+
void addUnrealizedMaterializations();
5861
};
5962

6063
// Autogenerated VHLO type printers and parsers.

stablehlo/tests/ops_stablehlo_quantized.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,15 +1338,15 @@ func.func @quantized_element_type_c8(%arg0: tensor<1x2x!quant.uniform<i8<-128:12
13381338

13391339
// -----
13401340

1341-
// expected-error@+1 {{scale out of expressed type range}}
1341+
// expected-error@+1 {{scale 1.055040e+05 out of expressed type range}}
13421342
func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform<i4:f16, 10.550400e+04>>) {
13431343
%0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform<i4:f16, 10.550400e+04>>
13441344
func.return
13451345
}
13461346

13471347
// -----
13481348

1349-
// expected-error@+1 {{scale out of expressed type range}}
1349+
// expected-error@+1 {{scale 4.960464e-08 out of expressed type range}}
13501350
func.func @quantized_element_type_c6(%arg0: tensor<1x2x!quant.uniform<i4:f16, 4.960464e-08>>) {
13511351
%0 = stablehlo.add %arg0, %arg0 : tensor<1x2x!quant.uniform<i4:f16, 4.960464e-08>>
13521352
func.return
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
2+
3+
// The script is designed to make adding checks to
4+
// a test case fast, it is *not* designed to be authoritative
5+
// about what constitutes a good test! The CHECK should be
6+
// minimized and named to reflect the test intent.
7+
8+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
9+
10+
// The script is designed to make adding checks to
11+
// a test case fast, it is *not* designed to be authoritative
12+
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
13+
14+
// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | FileCheck %s
15+
// RUN: stablehlo-opt %s > %t.0
16+
// RUN: stablehlo-opt %s --stablehlo-legalize-to-vhlo=allow-other-dialects | stablehlo-opt --vhlo-legalize-to-stablehlo > %t.1
17+
// RUN: diff %t.0 %t.1
18+
19+
// CHECK-LABEL: vhlo.func_v1 @op_other(
20+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
21+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<!vhlo.f32_v1> to tensor<f32>
22+
// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor<f32>
23+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor<f32> to !vhlo.tensor_v1<!vhlo.f32_v1>
24+
// CHECK: "vhlo.return_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
25+
// CHECK: }
26+
func.func @op_other(%arg0: tensor<f32>) -> tensor<f32> {
27+
%0 = arith.addf %arg0, %arg0 : tensor<f32>
28+
return %0 : tensor<f32>
29+
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: vhlo.func_v1 @op_shlo(
34+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
35+
// CHECK: %[[VAL_1:.*]] = "vhlo.add_v1"(%[[VAL_0]], %[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
36+
// CHECK: "vhlo.return_v1"(%[[VAL_1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
37+
// CHECK: }
38+
func.func @op_shlo(%arg0: tensor<f32>) -> tensor<f32> {
39+
%0 = stablehlo.add %arg0, %arg0 : tensor<f32>
40+
return %0 : tensor<f32>
41+
}
42+
43+
// -----
44+
45+
// CHECK-LABEL: vhlo.func_v1 @mixed_shlo_other_shlo(
46+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
47+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<!vhlo.f32_v1> to tensor<f32>
48+
// CHECK: %[[VAL_2:.*]] = "vhlo.abs_v1"(%[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
49+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : !vhlo.tensor_v1<!vhlo.f32_v1> to tensor<f32>
50+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_3]], %[[VAL_1]] : tensor<f32>
51+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : tensor<f32> to !vhlo.tensor_v1<!vhlo.f32_v1>
52+
// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
53+
// CHECK: "vhlo.return_v1"(%[[VAL_6]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
54+
// CHECK: }
55+
func.func @mixed_shlo_other_shlo(%arg0: tensor<f32>) -> tensor<f32> {
56+
%0 = stablehlo.abs %arg0 : tensor<f32>
57+
%1 = arith.addf %0, %arg0 : tensor<f32>
58+
%2 = stablehlo.abs %1 : tensor<f32>
59+
return %2 : tensor<f32>
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: vhlo.func_v1 @mixed_other_shlo_other(
65+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.f32_v1>) {
66+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<!vhlo.f32_v1> to tensor<f32>
67+
// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %[[VAL_1]] : tensor<f32>
68+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor<f32> to !vhlo.tensor_v1<!vhlo.f32_v1>
69+
// CHECK: %[[VAL_4:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_0]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
70+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<!vhlo.f32_v1> to tensor<f32>
71+
// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_1]] : tensor<f32>
72+
// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor<f32> to !vhlo.tensor_v1<!vhlo.f32_v1>
73+
// CHECK: "vhlo.return_v1"(%[[VAL_7]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
74+
// CHECK: }
75+
func.func @mixed_other_shlo_other(%arg0: tensor<f32>) -> tensor<f32> {
76+
%0 = arith.addf %arg0, %arg0 : tensor<f32>
77+
%1 = stablehlo.add %0, %arg0 : tensor<f32>
78+
%2 = arith.addf %1, %arg0 : tensor<f32>
79+
return %2 : tensor<f32>
80+
}
81+
82+
// -----
83+
84+
// CHECK-LABEL: vhlo.func_v1 @op_with_region(
85+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>,
86+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) {
87+
// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}}<dense<[1, 2]> : tensor<2xi64>>}> ({
88+
// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
89+
// CHECK: %[[VAL_5:.*]] = "vhlo.add_v1"(%[[VAL_3]], %[[VAL_4]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
90+
// CHECK: "vhlo.return_v1"(%[[VAL_5]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
91+
// CHECK: }) : (!vhlo.tensor_v1<1x16x16x320x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x320x!vhlo.f32_v1>
92+
// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<1x320x!vhlo.f32_v1>) -> ()
93+
// CHECK: }
94+
func.func @op_with_region(%arg0: tensor<1x16x16x320xf32>, %arg1: tensor<f32>) -> tensor<1x320xf32> {
95+
%0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add across dimensions = [1, 2] : (tensor<1x16x16x320xf32>, tensor<f32>) -> tensor<1x320xf32>
96+
return %0 : tensor<1x320xf32>
97+
}
98+
99+
// -----
100+
101+
// CHECK-LABEL: vhlo.func_v1 @op_with_region_mixed_other_shlo_other(
102+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>,
103+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) {
104+
// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}}<dense<0> : tensor<1xi64>>}> ({
105+
// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>):
106+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32>
107+
// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32>
108+
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_5]] : tensor<5xf32>
109+
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1>
110+
// CHECK: %[[VAL_9:.*]] = "vhlo.add_v1"(%[[VAL_8]], %[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1>
111+
// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32>
112+
// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_5]] : tensor<5xf32>
113+
// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1>
114+
// CHECK: "vhlo.return_v1"(%[[VAL_12]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> ()
115+
// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1>
116+
// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> ()
117+
// CHECK: }
118+
func.func @op_with_region_mixed_other_shlo_other(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> {
119+
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32>
120+
reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) {
121+
%1 = arith.addf %arg2, %arg3 : tensor<5xf32>
122+
%2 = stablehlo.add %1, %arg2 : tensor<5xf32>
123+
%3 = arith.addf %2, %arg3 : tensor<5xf32>
124+
stablehlo.return %3 : tensor<5xf32>
125+
}
126+
return %0 : tensor<5xf32>
127+
}
128+
129+
// -----
130+
131+
// CHECK-LABEL: vhlo.func_v1 @op_with_region_mixed_shlo_other_shlo(
132+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<7x5x!vhlo.f32_v1>,
133+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<5x!vhlo.f32_v1>) {
134+
// CHECK: %[[VAL_2:.*]] = "vhlo.reduce_v1"(%[[VAL_0]], %[[VAL_1]]) <{dimensions = #{{.*}}<dense<0> : tensor<1xi64>>}> ({
135+
// CHECK: ^bb0(%[[VAL_3:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>, %[[VAL_4:.*]]: !vhlo.tensor_v1<5x!vhlo.f32_v1>):
136+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32>
137+
// CHECK: %[[VAL_6:.*]] = "vhlo.abs_v1"(%[[VAL_3]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1>
138+
// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !vhlo.tensor_v1<5x!vhlo.f32_v1> to tensor<5xf32>
139+
// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_7]], %[[VAL_5]] : tensor<5xf32>
140+
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : tensor<5xf32> to !vhlo.tensor_v1<5x!vhlo.f32_v1>
141+
// CHECK: %[[VAL_10:.*]] = "vhlo.abs_v1"(%[[VAL_9]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1>
142+
// CHECK: "vhlo.return_v1"(%[[VAL_10]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> ()
143+
// CHECK: }) : (!vhlo.tensor_v1<7x5x!vhlo.f32_v1>, !vhlo.tensor_v1<5x!vhlo.f32_v1>) -> !vhlo.tensor_v1<5x!vhlo.f32_v1>
144+
// CHECK: "vhlo.return_v1"(%[[VAL_2]]) : (!vhlo.tensor_v1<5x!vhlo.f32_v1>) -> ()
145+
// CHECK: }
146+
func.func @op_with_region_mixed_shlo_other_shlo(%arg0: tensor<7x5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> {
147+
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32>
148+
reducer(%arg2: tensor<5xf32>, %arg3: tensor<5xf32>) {
149+
%1 = stablehlo.abs %arg2 : tensor<5xf32>
150+
%2 = arith.addf %1, %arg3 : tensor<5xf32>
151+
%3 = stablehlo.abs %2 : tensor<5xf32>
152+
stablehlo.return %3 : tensor<5xf32>
153+
}
154+
return %0 : tensor<5xf32>
155+
}
156+
157+
// -----
158+
159+
// CHECK-LABEL: vhlo.func_v1 @stablehlo_in_other_op_region(
160+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.tensor_v1<2x!vhlo.f32_v1>,
161+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !vhlo.index_v1) -> (!vhlo.tensor_v1<2x!vhlo.f32_v1>) {
162+
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32>
163+
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
164+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
165+
// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
166+
// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
167+
// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_9:.*]] = %[[VAL_2]]) -> (tensor<2xf32>) {
168+
// CHECK: %[[VAL_10:.*]] = tensor.insert %[[VAL_6]] into %[[VAL_9]]{{\[}}%[[VAL_8]]] : tensor<2xf32>
169+
// CHECK: %[[VAL_11:.*]] = builtin.unrealized_conversion_cast %[[VAL_10]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1>
170+
// CHECK: %[[VAL_12:.*]] = "vhlo.add_v1"(%[[VAL_11]], %[[VAL_11]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.f32_v1>) -> !vhlo.tensor_v1<2x!vhlo.f32_v1>
171+
// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !vhlo.tensor_v1<2x!vhlo.f32_v1> to tensor<2xf32>
172+
// CHECK: scf.yield %[[VAL_13]] : tensor<2xf32>
173+
// CHECK: }
174+
// CHECK: %[[VAL_14:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : tensor<2xf32> to !vhlo.tensor_v1<2x!vhlo.f32_v1>
175+
// CHECK: "vhlo.return_v1"(%[[VAL_14]]) : (!vhlo.tensor_v1<2x!vhlo.f32_v1>) -> ()
176+
// CHECK: }
177+
func.func @stablehlo_in_other_op_region(%arg0: tensor<2xf32>, %arg1: index) -> tensor<2xf32> {
178+
%c0 = arith.constant 0 : index
179+
%c1 = arith.constant 1 : index
180+
%c2 = arith.constant 2 : index
181+
%cst = arith.constant 0.0 : f32
182+
183+
%for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> {
184+
%new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32>
185+
%new_out_add = stablehlo.add %new_out, %new_out : tensor<2xf32>
186+
scf.yield %new_out_add : tensor<2xf32>
187+
}
188+
return %for : tensor<2xf32>
189+
}

0 commit comments

Comments
 (0)