Skip to content

Commit db551e3

Browse files
Eugene Burmakocopybara-github
authored andcommitted
Integrate StableHLO at openxla/stablehlo@52b6d47
Manual changes: * AssemblyFormat.cpp: reverted recent manual changes to the file to deal with LLVM deprecations, sent them to upstream StableHLO as a pull request: openxla/stablehlo#915. * TypeInference.cpp: still need a patch for inferGatherOp; I've opened a ticket to look into that: openxla/stablehlo#914. PiperOrigin-RevId: 501451535
1 parent ee61655 commit db551e3

File tree

2 files changed

+3
-104
lines changed

2 files changed

+3
-104
lines changed
Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,7 @@
1-
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h
2-
--- stablehlo/stablehlo/dialect/AssemblyFormat.h
3-
+++ stablehlo/stablehlo/dialect/AssemblyFormat.h
4-
@@ -59,7 +59,7 @@
5-
OpTypes... types) {
6-
static_assert(sizeof...(types) > 0); // Must be non empty, must have result
7-
SmallVector<Type> typesVec{types...};
8-
- ArrayRef<Type> typesRef = makeArrayRef(typesVec);
9-
+ ArrayRef<Type> typesRef = ArrayRef(typesVec);
10-
return detail::printSameOperandsAndResultTypeImpl(
11-
p, op, typesRef.drop_back(1), typesRef.back());
12-
}
13-
@@ -69,7 +69,7 @@
14-
OpTypes&... types) {
15-
static_assert(sizeof...(types) > 0); // Must be non empty, must have result
16-
SmallVector<Type*> typesVec{&types...};
17-
- ArrayRef<Type*> typesRef = makeArrayRef(typesVec);
18-
+ ArrayRef<Type*> typesRef = ArrayRef(typesVec);
19-
return detail::parseSameOperandsAndResultTypeImpl(
20-
parser, typesRef.drop_back(1), *typesRef.back());
21-
}
221
diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
232
--- stablehlo/stablehlo/dialect/TypeInference.cpp
243
+++ stablehlo/stablehlo/dialect/TypeInference.cpp
25-
@@ -2184,7 +2184,9 @@
4+
@@ -2186,7 +2186,9 @@
265
}
276

287
auto getSliceDim = [&sliceSizes](int64_t index) -> int64_t {
@@ -33,84 +12,4 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
3312
};
3413

3514
return inferGatherReturnTypeComponents(
36-
diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp
37-
--- stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp
38-
+++ stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp
39-
@@ -27,9 +27,9 @@
40-
const int64_t *insertedWindowDims, intptr_t nScatteredDimsToOperandDims,
41-
const int64_t *scatteredDimsToOperandDims, int64_t indexVectorDim) {
42-
return wrap(mlir::stablehlo::ScatterDimensionNumbersAttr::get(
43-
- unwrap(ctx), llvm::makeArrayRef(updateWindowDims, nUpdateWindowDims),
44-
- llvm::makeArrayRef(insertedWindowDims, nInsertedWindowDims),
45-
- llvm::makeArrayRef(scatteredDimsToOperandDims,
46-
+ unwrap(ctx), llvm::ArrayRef(updateWindowDims, nUpdateWindowDims),
47-
+ llvm::ArrayRef(insertedWindowDims, nInsertedWindowDims),
48-
+ llvm::ArrayRef(scatteredDimsToOperandDims,
49-
nScatteredDimsToOperandDims),
50-
indexVectorDim));
51-
}
52-
@@ -99,9 +99,9 @@
53-
intptr_t nStartIndexMap, const int64_t *startIndexMap,
54-
int64_t indexVectorDim) {
55-
return wrap(mlir::stablehlo::GatherDimensionNumbersAttr::get(
56-
- unwrap(ctx), llvm::makeArrayRef(offsetDims, nOffsetDims),
57-
- llvm::makeArrayRef(collapsedSliceDims, nCollapsedSliceDims),
58-
- llvm::makeArrayRef(startIndexMap, nStartIndexMap), indexVectorDim));
59-
+ unwrap(ctx), llvm::ArrayRef(offsetDims, nOffsetDims),
60-
+ llvm::ArrayRef(collapsedSliceDims, nCollapsedSliceDims),
61-
+ llvm::ArrayRef(startIndexMap, nStartIndexMap), indexVectorDim));
62-
}
63-
64-
bool stablehloAttributeIsAGatherDimensionNumbers(MlirAttribute attr) {
65-
@@ -170,10 +170,10 @@
66-
const int64_t *rhsContractingDimensions) {
67-
return wrap(mlir::stablehlo::DotDimensionNumbersAttr::get(
68-
unwrap(ctx),
69-
- llvm::makeArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions),
70-
- llvm::makeArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions),
71-
- llvm::makeArrayRef(lhsContractingDimensions, nLhsContractingDimensions),
72-
- llvm::makeArrayRef(rhsContractingDimensions, nRhsContractingDimensions)));
73-
+ llvm::ArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions),
74-
+ llvm::ArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions),
75-
+ llvm::ArrayRef(lhsContractingDimensions, nLhsContractingDimensions),
76-
+ llvm::ArrayRef(rhsContractingDimensions, nRhsContractingDimensions)));
77-
}
78-
79-
bool stablehloAttributeIsADotDimensionNumbers(MlirAttribute attr) {
80-
@@ -253,11 +253,11 @@
81-
intptr_t nOutputSpatialDimensions, const int64_t *outputSpatialDimensions) {
82-
return wrap(mlir::stablehlo::ConvDimensionNumbersAttr::get(
83-
unwrap(ctx), inputBatchDimension, inputFeatureDimension,
84-
- llvm::makeArrayRef(inputSpatialDimensions, nInputSpatialDimensions),
85-
+ llvm::ArrayRef(inputSpatialDimensions, nInputSpatialDimensions),
86-
kernelInputFeatureDimension, kernelOutputFeatureDimension,
87-
- llvm::makeArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions),
88-
+ llvm::ArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions),
89-
outputBatchDimension, outputFeatureDimension,
90-
- llvm::makeArrayRef(outputSpatialDimensions, nOutputSpatialDimensions)));
91-
+ llvm::ArrayRef(outputSpatialDimensions, nOutputSpatialDimensions)));
92-
}
93-
94-
bool stablehloAttributeIsAConvDimensionNumbers(MlirAttribute attr) {
95-
@@ -360,9 +360,9 @@
96-
const int64_t *outputTupleIndices, int64_t operandIndex,
97-
intptr_t nOperandTupleIndices, const int64_t *operandTupleIndices) {
98-
return wrap(mlir::stablehlo::OutputOperandAliasAttr::get(
99-
- unwrap(ctx), llvm::makeArrayRef(outputTupleIndices, nOutputTupleIndices),
100-
+ unwrap(ctx), llvm::ArrayRef(outputTupleIndices, nOutputTupleIndices),
101-
operandIndex,
102-
- llvm::makeArrayRef(operandTupleIndices, nOperandTupleIndices)));
103-
+ llvm::ArrayRef(operandTupleIndices, nOperandTupleIndices)));
104-
}
105-
106-
bool stablehloAttributeIsAOutputOperandAlias(MlirAttribute attr) {
107-
@@ -586,7 +586,7 @@
108-
MlirAttribute stablehloTypeExtensionsGet(MlirContext ctx, intptr_t nBounds,
109-
const int64_t *bounds) {
110-
return wrap(mlir::stablehlo::TypeExtensionsAttr::get(
111-
- unwrap(ctx), llvm::makeArrayRef(bounds, nBounds)));
112-
+ unwrap(ctx), llvm::ArrayRef(bounds, nBounds)));
113-
}
114-
115-
bool stablehloAttributeIsTypeExtensions(MlirAttribute attr) {
11615

third_party/stablehlo/workspace.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
44

55
def repo():
66
# LINT.IfChange
7-
STABLEHLO_COMMIT = "395f63efee5f0427be1bf6dac36fba5db215069d"
8-
STABLEHLO_SHA256 = "14aaf1fcc29373e80cda002057d8433e61ecaf1e5980d2e6d4ffc745cb9733dd"
7+
STABLEHLO_COMMIT = "52b6d47e4db763708634415247ab47f4a6c180bd"
8+
STABLEHLO_SHA256 = "e94432be6e56bffd046ea44fc3cd577d1d1ca114940bbca716cc0e266db0da98"
99
# LINT.ThenChange(Google-internal path)
1010

1111
tf_http_archive(

0 commit comments

Comments
 (0)