Skip to content

Commit 8d9a84b

Browse files
authored
1 parent 81d2ca2 commit 8d9a84b

File tree

8 files changed

+11
-10
lines changed

8 files changed

+11
-10
lines changed

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "799e9053641a6478d3144866a97737b37b87c260"
20+
LLVM_COMMIT = "69f59d59cb02c06f1fac93ea5b19c2df9a684109"
2121

22-
LLVM_SHA256 = "be33f1f9f20da6bd744d62356bf469e906e3b5f5e9cba2af6ee6418cee49f1f3"
22+
LLVM_SHA256 = "2fd8dcec1da1c7166d58918d5f6330856edb37351248a5947661055313bb5d46"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
799e9053641a6478d3144866a97737b37b87c260
1+
69f59d59cb02c06f1fac93ea5b19c2df9a684109

stablehlo/testdata/fft_complex128_14_15_0_17.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
1212
return %2 : tensor<14x15x0x33xf64>
1313
}
1414
func.func private @inputs() -> (tensor<14x15x0x17xcomplex<f64>> {mhlo.layout_mode = "default"}) {
15-
%cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex<f64>>
15+
%cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex<f64>>
1616
return %cst : tensor<14x15x0x17xcomplex<f64>>
1717
}
1818
func.func private @expected() -> (tensor<14x15x0x33xf64> {mhlo.layout_mode = "default"}) {

stablehlo/testdata/fft_complex64_14_15_0_17.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
1212
return %2 : tensor<14x15x0x33xf32>
1313
}
1414
func.func private @inputs() -> (tensor<14x15x0x17xcomplex<f32>> {mhlo.layout_mode = "default"}) {
15-
%cst = stablehlo.constant dense<> : tensor<14x15x0x17xcomplex<f32>>
15+
%cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x17xcomplex<f32>>
1616
return %cst : tensor<14x15x0x17xcomplex<f32>>
1717
}
1818
func.func private @expected() -> (tensor<14x15x0x33xf32> {mhlo.layout_mode = "default"}) {

stablehlo/testdata/fft_float32_14_15_0_17.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
1616
return %cst : tensor<14x15x0x17xf32>
1717
}
1818
func.func private @expected() -> (tensor<14x15x0x9xcomplex<f32>> {mhlo.layout_mode = "default"}) {
19-
%cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex<f32>>
19+
%cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex<f32>>
2020
return %cst : tensor<14x15x0x9xcomplex<f32>>
2121
}
2222
}

stablehlo/testdata/fft_float64_14_15_0_17.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
1616
return %cst : tensor<14x15x0x17xf64>
1717
}
1818
func.func private @expected() -> (tensor<14x15x0x9xcomplex<f64>> {mhlo.layout_mode = "default"}) {
19-
%cst = stablehlo.constant dense<> : tensor<14x15x0x9xcomplex<f64>>
19+
%cst = stablehlo.constant dense<(0.0, 0.0)> : tensor<14x15x0x9xcomplex<f64>>
2020
return %cst : tensor<14x15x0x9xcomplex<f64>>
2121
}
2222
}

stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,8 +1539,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
15391539

15401540
void populateStablehloHloImportCanonicalizationPatterns(
15411541
MLIRContext *context, RewritePatternSet *patterns) {
1542-
patterns->add<TupleIsRepacking, TupleIsUnpacked, WhileOpImplicitCapture>(
1543-
context);
1542+
patterns->add<ReshapeIsNoop, TupleIsRepacking, TupleIsUnpacked,
1543+
WhileOpImplicitCapture>(context);
15441544
}
15451545

15461546
std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(

stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def : Pat<(StableHLO_ReshapeOp:$reshape (StableHLO_ReshapeOp $operand)),
366366
(StableHLO_ReshapeOpWithShape $reshape, $operand)>;
367367

368368
// Pattern: reshape(X, [X.shape]) -> X
369-
def : Pat<(StableHLO_ReshapeOp:$reshape $operand),
369+
def ReshapeIsNoop
370+
: Pat<(StableHLO_ReshapeOp:$reshape $operand),
370371
(replaceWithValue $operand),
371372
[(TypesEqual $reshape, $operand)]>;
372373

0 commit comments

Comments
 (0)