@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
105105static MeshOp getMesh (Operation *op, ArrayRef<MeshSharding> operandShardings,
106106 ArrayRef<MeshSharding> resultShardings,
107107 SymbolTableCollection &symbolTable) {
108- for (const MeshSharding& sharding : operandShardings) {
108+ for (const MeshSharding & sharding : operandShardings) {
109109 if (sharding) {
110110 return mesh::getMesh (op, sharding.getMeshAttr (), symbolTable);
111111 }
112112 }
113113
114- for (const MeshSharding& sharding : resultShardings) {
114+ for (const MeshSharding & sharding : resultShardings) {
115115 if (sharding) {
116116 return mesh::getMesh (op, sharding.getMeshAttr (), symbolTable);
117117 }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
129129// the original operand.
130130// The other processes would use the reduction operation neutral tensor.
131131static Value createDestinationPassingStyleInitOperand (
132- LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133- MeshOp meshOp, ImplicitLocOpBuilder &builder) {
132+ LinalgOp op, int operandNumber, Value spmdizedOperand,
133+ ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134+ ImplicitLocOpBuilder &builder) {
134135 Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex (
135136 meshOp.getSymName (), reductionMeshAxes, builder);
136137 Value zero = builder.create <arith::ConstantIndexOp>(0 );
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
152153 builder.setInsertionPointToEnd (&ifOp.getElseRegion ().front ());
153154 SmallVector<OpFoldResult> shape =
154155 tensor::getMixedSizes (builder, builder.getLoc (), spmdizedOperand);
155- PartialReductionOpInterface partialReductionIface =
156- llvm::cast<PartialReductionOpInterface>(op.getOperation ());
157- assert (op->getNumResults () == 1 && " Multiple results not supported." );
158- FailureOr<SmallVector<Value>> reductionNeutralTensor =
159- partialReductionIface.generateInitialTensorForPartialReduction (
160- builder, builder.getLoc (), shape, {});
161- assert (succeeded (reductionNeutralTensor));
162- builder.create <scf::YieldOp>(reductionNeutralTensor.value ());
156+
157+ SmallVector<Operation *> combinerOps;
158+ matchReduction (op.getRegionOutputArgs (), operandNumber, combinerOps);
159+ assert (combinerOps.size () == 1 );
160+ std::optional<TypedAttr> neutralEl =
161+ arith::getNeutralElement (combinerOps[0 ]);
162+
163+ Value init = builder.create <tensor::EmptyOp>(op.getLoc (), shape,
164+ neutralEl.value ().getType ());
165+ Value constant =
166+ builder.create <arith::ConstantOp>(op.getLoc (), neutralEl.value ());
167+ Value fill = builder.create <linalg::FillOp>(op.getLoc (), constant, init)
168+ .getResult (0 );
169+
170+ builder.create <scf::YieldOp>(fill);
163171 }
164172 return ifOp.getResult (0 );
165173}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
178186 Value spmdizedInitOperand =
179187 spmdizationMap.lookup (op->getOperands ()[operandIdx]);
180188 newOperands[operandIdx] = createDestinationPassingStyleInitOperand (
181- op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
189+ op, 0 , spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182190 return newOperands;
183191}
184192
0 commit comments