Skip to content

Commit 7a2705d

Browse files
bjacobhanhanW
andauthored
Bump stablehlo to f7f8e4e35 and drop LLVM local reverts (#18668)
Continuing from @hanhanW 's #18659: Stablehlo cherry-picks: 1. openxla/stablehlo#2572 2. openxla/stablehlo#2573 Torch-mlir cherry-picks: 1. llvm/torch-mlir#3755 --------- Signed-off-by: hanhanW <[email protected]> Signed-off-by: Benoit Jacob <[email protected]> Co-authored-by: hanhanW <[email protected]>
1 parent d341128 commit 7a2705d

File tree

11 files changed

+23
-20
lines changed

11 files changed

+23
-20
lines changed

compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -491,15 +491,15 @@ struct AllGatherOpConversion final
491491
op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter);
492492

493493
// Get the collective element type attribute.
494-
auto resultType = cast<RankedTensorType>(op.getResult().getType());
494+
auto resultType = cast<RankedTensorType>(op.getResult(0).getType());
495495
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
496496
IREE::Flow::getCollectiveElementTypeAttr(resultType);
497497
if (!elementTypeAttr) {
498498
return rewriter.notifyMatchFailure(
499499
op, "unsupported element type for collective op");
500500
}
501501
uint64_t allGatherDim = op.getAllGatherDim();
502-
Value gatherInput = adaptor.getOperand();
502+
Value gatherInput = adaptor.getOperands()[0];
503503
SmallVector<int64_t> gatherResultShape(resultType.getShape());
504504

505505
// When all_gather_dim != 0, we need to transpose between 0 and
@@ -513,7 +513,7 @@ struct AllGatherOpConversion final
513513
// Create an empty tensor for the result.
514514
Value target = rewriter.create<tensor::EmptyOp>(
515515
loc, gatherResultShape,
516-
getElementTypeOrSelf(adaptor.getOperand().getType()));
516+
getElementTypeOrSelf(adaptor.getOperands()[0].getType()));
517517
Value gatherResult = rewriter.create<IREE::Flow::CollectiveAllGatherOp>(
518518
op.getLoc(), elementTypeAttr, target, gatherInput, channel);
519519

@@ -585,7 +585,7 @@ struct AllReduceOpConversion final
585585
auto reductionOpAttr =
586586
IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp);
587587

588-
auto inputType = cast<RankedTensorType>(op.getOperand().getType());
588+
auto inputType = cast<RankedTensorType>(op.getOperand(0).getType());
589589

590590
// Get the collective element type attribute.
591591
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
@@ -597,10 +597,11 @@ struct AllReduceOpConversion final
597597
// Create an empty tensor for the result.
598598
ArrayRef<int64_t> inputShape = inputType.getShape();
599599
Value target = rewriter.create<tensor::EmptyOp>(
600-
loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType()));
600+
loc, inputShape,
601+
getElementTypeOrSelf(adaptor.getOperands()[0].getType()));
601602
auto allReduceOp = rewriter.create<IREE::Flow::CollectiveAllReduceOp>(
602603
op.getLoc(), reductionOpAttr, elementTypeAttr, target,
603-
adaptor.getOperand(), channel);
604+
adaptor.getOperands()[0], channel);
604605
rewriter.replaceOp(op, allReduceOp.getResult());
605606
return success();
606607
}
@@ -676,7 +677,7 @@ struct AllToAllOpConversion final
676677
op.getReplicaGroups(), /*useGlobalDeviceIds=*/std::nullopt, rewriter);
677678

678679
// Get the collective element type attribute.
679-
auto resultType = cast<RankedTensorType>(op.getType());
680+
auto resultType = cast<RankedTensorType>(op.getType(0));
680681
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
681682
IREE::Flow::getCollectiveElementTypeAttr(resultType);
682683
if (!elementTypeAttr) {
@@ -687,7 +688,7 @@ struct AllToAllOpConversion final
687688
uint64_t splitDim = op.getSplitDimension();
688689
uint64_t concatDim = op.getConcatDimension();
689690
uint64_t splitCount = op.getSplitCount();
690-
Value allToAllInput = adaptor.getOperand();
691+
Value allToAllInput = adaptor.getOperands()[0];
691692

692693
// When splitDim != 0, we need to transpose splitDim to 0 before and after
693694
// the all-to-all.

compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ struct GeneralDotRemoveBatch final
198198

199199
auto dot = rewriter.create<mlir::stablehlo::DotGeneralOp>(
200200
op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs,
201-
newDimNumbers, op.getPrecisionConfigAttr());
201+
newDimNumbers, op.getPrecisionConfigAttr(), op.getAlgorithmAttr());
202202
rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty,
203203
dot.getResult());
204204
return success();

compiler/plugins/input/StableHLO/Conversion/Preprocessing/EinsumToDotGeneral.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ struct EinsumToDotGeneralPattern final
141141
auto dotGeneralOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
142142
einsum.getLoc(), dotGeneralResultType, einsum.getLhs(), einsum.getRhs(),
143143
dimNumbers,
144-
/*precision_config=*/ArrayAttr{});
144+
/*precision_config=*/ArrayAttr{}, mlir::stablehlo::DotAlgorithmAttr{});
145145

146146
if (isNaturalOrder) {
147147
// The dot_general is already in an appropriate result order.

compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ struct TransposeReshapeGenericDotGeneral final
424424

425425
auto newOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
426426
op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
427-
op.getPrecisionConfigAttr());
427+
op.getPrecisionConfigAttr(), op.getAlgorithmAttr());
428428

429429
// Copy over unknown attributes as we currently rely on it to let user tune
430430
// lowering parameters.

compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ struct HouseholderReflectorRewriter final
200200
auto dotNums = mlir::stablehlo::DotDimensionNumbersAttr::get(
201201
b.getContext(), batch, batch, lhsContract, rhsContract);
202202
Value dot = b.create<mlir::stablehlo::DotGeneralOp>(
203-
householder0.getType(), args[0], householder, dotNums, nullptr);
203+
householder0.getType(), args[0], householder, dotNums, nullptr,
204+
mlir::stablehlo::DotAlgorithmAttr{});
204205
b.create<scf::YieldOp>(loc, dot);
205206
});
206207

compiler/plugins/input/StableHLO/Conversion/test/legalize_control_flow.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func.func @conditional_nested(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<
150150
func.func @case2(%arg0 : tensor<i32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> {
151151

152152
// CHECK-NEXT: %[[VAL_3:.*]] = stablehlo.constant dense<0> : tensor<i32>
153-
// CHECK: %[[VAL_4:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_3]], NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
153+
// CHECK: %[[VAL_4:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
154154
// CHECK: %[[VAL_5:.*]] = tensor.extract %[[VAL_4]][] : tensor<i1>
155155
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (tensor<4xf32>) {
156156
%1 = "stablehlo.case"(%arg0) ({
@@ -180,7 +180,7 @@ func.func @case2(%arg0 : tensor<i32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf3
180180
func.func @case3(%arg0 : tensor<i32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>, %arg3 : tensor<4xf32>) -> tensor<4xf32> {
181181

182182
// CHECK-NEXT: %[[VAL_4:.*]] = stablehlo.constant dense<0> : tensor<i32>
183-
// CHECK: %[[VAL_5:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_4]], NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
183+
// CHECK: %[[VAL_5:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_4]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
184184
// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_5]][] : tensor<i1>
185185
// CHECK: %[[VAL_7:.*]] = scf.if %[[VAL_6]] -> (tensor<4xf32>) {
186186
%1 = "stablehlo.case"(%arg0) ({
@@ -191,7 +191,7 @@ func.func @case3(%arg0 : tensor<i32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf3
191191

192192
// CHECK: } else {
193193
// CHECK-NEXT: %[[VAL_9:.*]] = stablehlo.constant dense<1> : tensor<i32>
194-
// CHECK: %[[VAL_10:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_9]], NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
194+
// CHECK: %[[VAL_10:.*]] = stablehlo.compare EQ, %[[VAL_0]], %[[VAL_9]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
195195
// CHECK: %[[VAL_11:.*]] = tensor.extract %[[VAL_10]][] : tensor<i1>
196196
// CHECK: %[[VAL_12:.*]] = scf.if %[[VAL_11]] -> (tensor<4xf32>) {
197197
}, {

compiler/src/iree/compiler/Tools/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ iree_compiler_cc_library(
113113
"@llvm-project//mlir:MLProgramDialect",
114114
"@llvm-project//mlir:MathDialect",
115115
"@llvm-project//mlir:MemRefDialect",
116+
"@llvm-project//mlir:QuantOps",
116117
"@llvm-project//mlir:SCFDialect",
117118
"@llvm-project//mlir:SCFToGPU",
118119
"@llvm-project//mlir:SCFTransforms",

compiler/src/iree/compiler/Tools/init_mlir_dialects.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
4141
#include "mlir/Dialect/PDL/IR/PDL.h"
4242
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
43-
#include "mlir/Dialect/Quant/QuantOps.h"
43+
#include "mlir/Dialect/Quant/IR/Quant.h"
4444
#include "mlir/Dialect/SCF/IR/SCF.h"
4545
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
4646
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -81,7 +81,7 @@ inline void registerMlirDialects(DialectRegistry &registry) {
8181
pdl::PDLDialect,
8282
pdl_interp::PDLInterpDialect,
8383
scf::SCFDialect,
84-
quant::QuantizationDialect,
84+
quant::QuantDialect,
8585
spirv::SPIRVDialect,
8686
arm_neon::ArmNeonDialect,
8787
arm_sve::ArmSVEDialect,

third_party/llvm-project

Submodule llvm-project updated 39 files

third_party/stablehlo

Submodule stablehlo updated 6601 files

0 commit comments

Comments
 (0)