@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
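
For readers skimming the removal: `expandRank` prepended unit dimensions by folding them, together with the original outermost dimension, into the first reassociation group of a `tensor.expand_shape`. Below is a minimal standalone sketch of that index computation with no MLIR dependencies; `expandReassociation` is a hypothetical name used purely for illustration:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone re-derivation of the reassociation indices the removed
// expandRank helper built: every new leading unit dim, plus the original
// outermost dim, folds into group 0; each remaining original dim gets its
// own group.
std::vector<std::vector<int64_t>> expandReassociation(int64_t tensorRank,
                                                      int64_t numExtraDims) {
  std::vector<std::vector<int64_t>> groups(tensorRank);
  int64_t index = 0;
  if (tensorRank != 0) {
    for (index = 0; index <= numExtraDims; index++)
      groups[0].push_back(index);
    for (size_t position = 1; position < groups.size(); position++)
      groups[position].push_back(index++);
  }
  return groups;
}

int main() {
  // Expanding tensor<3x4xf32> (rank 2) to rank 4, i.e. tensor<1x1x3x4xf32>,
  // yields the groups [0, 1, 2] and [3].
  for (const auto &group : expandReassociation(/*tensorRank=*/2,
                                               /*numExtraDims=*/2)) {
    for (int64_t i : group)
      std::cout << i << ' ';
    std::cout << '\n';
  }
  return 0;
}
```
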
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
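
The new helper centralizes which trailing operands must be excluded from broadcasting; the next hunk shows it replacing an inline special case for `tosa.mul`. As a rough standalone restatement of the rule, with string operand lists standing in for the real `ValueRange` (the function and operand names below are illustrative, not from the patch):

```cpp
#include <iostream>
#include <string>
#include <vector>

// tosa.mul carries a trailing shift operand, and tosa.negate carries
// input1_zp/output_zp operands; none of these participate in broadcasting,
// so they are cut off before the shapes are reconciled.
std::vector<std::string> broadcastableOperands(const std::string &opName,
                                               std::vector<std::string> operands) {
  if (opName == "tosa.mul" && operands.size() > 2)
    operands.resize(2); // keep input1, input2; drop shift
  else if (opName == "tosa.negate" && operands.size() > 1)
    operands.resize(1); // keep input1; drop input1_zp, output_zp
  return operands;
}

int main() {
  // tosa.mul(%a, %b, %shift): only %a and %b are broadcast.
  for (const auto &v : broadcastableOperands("tosa.mul", {"%a", "%b", "%shift"}))
    std::cout << v << '\n';
  return 0;
}
```
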
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                     targetShape, converter);
 }
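
With the rank-expansion step removed, the helper hands the operands straight to `computeTargetShape` and `broadcastDynamicDimensions`, presumably because all broadcastable operands are now expected to already share the result's rank. A simplified static-shape sketch of same-rank broadcasting under that assumption (illustrative only, not the actual `computeTargetShape`, which resolves dimensions at runtime using the `IndexPool` cache of `arith.constant` index ops):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified, static-shape view of broadcasting two equal-rank shapes.
// -1 stands in for a dynamic dimension (ShapedType::kDynamic in MLIR).
std::vector<int64_t> broadcastEqualRank(const std::vector<int64_t> &a,
                                        const std::vector<int64_t> &b) {
  assert(a.size() == b.size() && "operands must already share a rank");
  std::vector<int64_t> result(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == 1)
      result[i] = b[i];                       // unit dim stretches
    else if (b[i] == 1 || b[i] == a[i])
      result[i] = a[i];                       // other side stretches or matches
    else
      result[i] = (a[i] == -1) ? b[i] : a[i]; // prefer the static size
  }
  return result;
}

int main() {
  // Broadcasting tensor<2x?x1xf32> with tensor<2x3x4xf32> gives 2x3x4.
  for (int64_t d : broadcastEqualRank({2, -1, 1}, {2, 3, 4}))
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}
```
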