@@ -133,8 +133,13 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
   return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
-static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
-  return b.create<tensor::DimOp>(loc, v, dimension);
+static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
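+  // If the dimension is statically known from the tensor type, materialize
+  // it as a constant index instead of emitting a tensor.dim op.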
+  if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
+    if (!tensorType.isDynamicDim(dim))
+      return b.create<arith::ConstantOp>(
+          loc, b.getIndexAttr(tensorType.getShape()[dim]));
+  }
+  return b.create<tensor::DimOp>(loc, v, dim);
 }
 
 static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
@@ -2671,84 +2676,214 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     Location loc = op.getLoc();
     Value input = adaptor.self();
     auto inputType = input.getType().cast<RankedTensorType>();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t inputRank = inputType.getRank();
     TypeConverter *typeConverter = getTypeConverter();
     auto resultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    // When we only have expansion of dimensions in `aten.View`, the output
-    // tensor rank will be strictly greater than the input tensor rank.
-    // TODO: Handle the cases of the `aten.View` op where:
-    // 1. One or multiple dimensions are collapsed.
-    // 2. Some dimensions are expanded while others are collapsed.
-    if (inputRank >= resultRank) {
+    // Currently, we only handle the cases where the view either expands OR
+    // collapses dimensions. We do not handle expanding AND collapsing in the
+    // same view, nor views that do neither, such as a view of [2, 3] on a
+    // 3x2 tensor.
+    // TODO: For the expanding AND collapsing case, identify which dimensions
+    // are collapsing and which are expanding, and do it in two steps.
+    // TODO: For views that neither collapse nor expand, find an intermediate
+    // shape to collapse to and then expand from, e.g. [2, 3] => [6] => [3, 2].
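+    // For example, [6, 4] => [2, 3, 4] is a pure expansion and
+    // [2, 3, 4] => [6, 4] a pure collapse; both are handled below.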
+    if (inputRank == resultRank)
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: operand tensor rank should be strictly less than "
-              "the desired output rank");
-    }
+          op, "unimplemented: the view op is neither expanding nor collapsing");
+
+    if (resultRank == 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "result shape of rank 0 is invalid");
+
+    // TODO: Add support for expanding a rank-0 input to a size-1 result.
+    if (inputRank == 0)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: input rank 0 is not supported");
+
+    bool isCollapse = inputRank > resultRank;
+    int64_t collapsedRank = isCollapse ? resultRank : inputRank;
+    int64_t expandedRank = isCollapse ? inputRank : resultRank;
26892710
26902711 // Extract the desired output size as a list of integers. This list should
26912712 // have been created using the operation `torch.prim.ListConstruct`.
2692- SmallVector<Value> expectedSizeTorchInt ;
2693- if (!getListConstructElements (op.size (), expectedSizeTorchInt )) {
2713+ SmallVector<Value> outputSizeTorchInt ;
2714+ if (!getListConstructElements (op.size (), outputSizeTorchInt )) {
26942715 return rewriter.notifyMatchFailure (op,
2695- " unimplemented: the desired size is "
2716+ " unimplemented: the target size is "
26962717 " not constructed from ListConstruct" );
26972718 }
2698- SmallVector<Value> expectedSize = getTypeConvertedValues (
2699- rewriter, loc, typeConverter, expectedSizeTorchInt );
2700- if (resultRank != (int64_t )expectedSize .size ()) {
2719+ SmallVector<Value> outputSizeInt = getTypeConvertedValues (
2720+ rewriter, loc, typeConverter, outputSizeTorchInt );
2721+ if (resultRank != (int64_t )outputSizeInt .size ()) {
27012722 return rewriter.notifyMatchFailure (
27022723 op, " desired size list length mismatches with the result type rank" );
27032724 }
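+    // Keep the sizes around as SSA values as well; they are needed later to
+    // emit runtime dim-equality checks for dynamic dimensions.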
+    SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> expandedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
+    ArrayRef<Value> collapsedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);
 
-    // Check if the `aten.View` can be legalized to `linalg.TensorExpandShape`.
-    // It only handles the case of static dimension expansion. If the dimension
-    // is dynamic, it must not be expanded/split.
-    // TODO: Handle the case of dynamic dimension expansion.
-    SmallVector<ReassociationIndices> reassociation(inputRank);
-    SmallVector<int64_t> resultShape;
-    int64_t j = 0;
-    for (auto i : llvm::seq<int64_t>(0, inputRank)) {
-      if (inputType.isDynamicDim(i)) {
-        Value dim = getDimOp(rewriter, loc, input, i);
-        if (j >= resultRank) {
-          return rewriter.notifyMatchFailure(
-              op, "desired size is not compatible with the input tensor size");
+    // Iterate through the view op size list to do the following:
+    //
+    // 1. Combine the output size list and the input tensor type information
+    // to get the most static outputShape.
+    //
+    // 2. Fill in the reassociation for each size list item whose output dim
+    // size comes from `torch.aten.size.int(inputTensor, inputDim)`. We
+    // naively assume this means the corresponding dimension is not expanded
+    // or collapsed. Note this may technically not always be true.
+    // TODO: Think of a better way to at least detect when this assumption is
+    // violated.
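+    // Illustrative example: for `t.view([t.size(0), 12])`, the first size
+    // list element matches torch.aten.size.int(t, 0), so output dim 0 is
+    // assumed to be input dim 0, unchanged.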
+    SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
+    SmallVector<ReassociationIndices> reassociation(collapsedRank);
+    for (auto en : llvm::enumerate(outputSizeTorchInt)) {
+      int64_t inputDim;
+      int64_t outputDim = en.index();
+      // Match torch.aten.size.int(inputTensor, inputDim) with a constant
+      // inputDim.
+      if (matchPattern(en.value(),
+                       m_TorchTensorSizeInt(op.self(), &inputDim))) {
+        auto collapsedDim = isCollapse ? outputDim : inputDim;
+        auto expandedDim = isCollapse ? inputDim : outputDim;
+        reassociation[collapsedDim].push_back(expandedDim);
+        if (!inputType.isDynamicDim(inputDim)) {
+          outputShape[outputDim] = inputShape[inputDim];
+          continue;
         }
-        checkDimEqualHelper(rewriter, loc, dim, expectedSize[j]);
-        reassociation[i].push_back(j++);
-        resultShape.push_back(kUnknownSize);
-      } else {
-        int64_t expandedDim = inputType.getDimSize(i);
-        int64_t outputDim;
-        // A do-while loop is used here to handle the cases where the input
-        // tensor has a dimension of size 1.
-        do {
-          if (j >= resultRank ||
-              !matchPattern(expectedSizeTorchInt[j],
-                            m_TorchConstantInt(&outputDim)) ||
-              expandedDim % outputDim != 0) {
+      }
+
+      int64_t size;
+      if (matchPattern(en.value(), m_TorchConstantInt(&size)))
+        outputShape[outputDim] = size;
+    }
+
+    SmallVector<int64_t> collapsedShape =
+        isCollapse ? outputShape : llvm::to_vector(inputShape);
+    SmallVector<int64_t> expandedShape =
+        isCollapse ? llvm::to_vector(inputShape) : outputShape;
+
+    // The while loop does the following:
+    // 1. Fill in the reassociation indices for dimensions that are expanded.
+    // Check the interval of dimensions between two unchanged dims in the
+    // collapsedShape. If the interval is size 1, associate all the dims in
+    // the expandedShape until the next unchanged dim. If the interval is
+    // larger than size 1, figure out the associations under the assumption
+    // that dynamic dimensions are not split.
+    // 2. Set collapsedShape and expandedShape following the requirements of
+    // the tensor.expand_shape verification code:
+    // a. As long as one or more of the related dimensions in the expanded
+    // shape is dynamic, the collapsed dimension is dynamic.
+    // b. If all of the related dimensions are static, the collapsed
+    // dimension must be static. In other words, if a collapsed dimension is
+    // dynamic, at least one of the related dimensions needs to be dynamic.
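+    // Illustrative example: collapsing tensor<2x3x?xf32> to [6, ?] yields
+    // reassociation [[0, 1], [2]], with collapsed dim 0 the static product 6
+    // and collapsed dim 1 dynamic.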
+    int64_t collapsedDim = 0, expandedDim = 0;
+    while (collapsedDim < collapsedRank && expandedDim < expandedRank) {
+      // A non-empty entry means the associations have already been filled in
+      // and the dimension is unchanged.
+      if (!reassociation[collapsedDim].empty()) {
+        if (expandedDim != reassociation[collapsedDim][0])
+          return op.emitOpError("unsupported: expanded dims are off from the "
+                                "expected dim from the reassociation");
+        collapsedDim++;
+        expandedDim++;
+        continue;
+      }
+
+      // Collect the dims that are collapsed until hitting the next dim that's
+      // unchanged.
+      SmallVector<int64_t> collapsedDims;
+      while (collapsedDim < collapsedRank &&
+             reassociation[collapsedDim].empty()) {
+        collapsedDims.push_back(collapsedDim);
+        collapsedDim++;
+      }
+      // The next reassociation is for a dim that's unchanged.
+      int64_t expandedDimNext = collapsedDim != collapsedRank
+                                    ? reassociation[collapsedDim][0]
+                                    : expandedRank;
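+      // A single collapsed dim absorbs every expanded dim in the interval,
+      // e.g. collapsing [2, 3] into [6] gives reassociation [[0, 1]].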
+      if (collapsedDims.size() == 1) {
+        int64_t collapsedDimSize = 1;
+        int64_t collapsedDim = collapsedDims[0];
+        for (auto i : llvm::seq<int64_t>(expandedDim, expandedDimNext)) {
+          reassociation[collapsedDim].push_back(i);
+          if (collapsedDimSize == kUnknownSize)
+            continue;
+
+          int64_t expandedDimSize = expandedShape[i];
+          if (expandedDimSize == kUnknownSize) {
+            collapsedDimSize = kUnknownSize;
+            continue;
+          }
+          collapsedDimSize *= expandedShape[i];
+        }
+        // To meet both requirements of the tensor.expand_shape verification
+        // code.
+        collapsedShape[collapsedDim] = collapsedDimSize;
+        expandedDim = expandedDimNext;
+        continue;
+      }
+
+      // collapsedDims are expanded to [expandedDim, expandedDimNext).
+      if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: mix of expanding and collapsing operations "
+                "for view");
+      for (auto collapsedDim : collapsedDims) {
+        if (collapsedShape[collapsedDim] == kUnknownSize) {
+          if (expandedDim >= expandedDimNext) {
             return rewriter.notifyMatchFailure(
-                op, "total number of elements mismatch in the expansion");
+                op,
+                "desired size is not compatible with the input tensor size");
           }
-          reassociation[i].push_back(j++);
-          resultShape.push_back(outputDim);
-          expandedDim /= outputDim;
-        } while (expandedDim != 1);
+          checkDimEqualHelper(rewriter, loc,
+                              collapsedShapeTorchInt[collapsedDim],
+                              expandedShapeTorchInt[expandedDim]);
+          // To meet the second requirement of the tensor.expand_shape
+          // verification code.
+          expandedShape[expandedDim] = kUnknownSize;
+          reassociation[collapsedDim].push_back(expandedDim++);
+        } else {
+          int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
+          // A do-while loop is used here to handle the case where the
+          // collapsed shape has a dimension of size 1.
+          do {
+            // Bounds-check expandedDim before reading expandedShape.
+            if (expandedDim >= expandedDimNext ||
+                expandedShape[expandedDim] == kUnknownSize ||
+                remainingSizeToExpand % expandedShape[expandedDim] != 0) {
+              return rewriter.notifyMatchFailure(
+                  op, "total number of elements mismatch in the expansion");
+            }
+            remainingSizeToExpand /= expandedShape[expandedDim];
+            reassociation[collapsedDim].push_back(expandedDim++);
+          } while (remainingSizeToExpand != 1);
+        }
+      }
     }
-    // Make sure that the split dimensions have the same number of elements
-    // as the dimension they were split from.
-    if (j != resultRank)
-      return rewriter.notifyMatchFailure(
-          op, "desired size is not compatible with the input tensor size");
 
-    Type expandType =
-        RankedTensorType::get(resultShape, resultType.getElementType());
-    Value expandOp = rewriter.create<linalg::TensorExpandShapeOp>(
-        loc, expandType, adaptor.self(), reassociation);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expandOp);
+    if (collapsedDim != collapsedRank || expandedDim != expandedRank)
+      return rewriter.notifyMatchFailure(op, "view shape is not supported");
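+    // Cast the input to the adjusted (most static) type, build the collapse
+    // or expand op on it, then cast the result back to the converted result
+    // type.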
+    Type adjustedResultType =
+        RankedTensorType::get(isCollapse ? collapsedShape : expandedShape,
+                              resultType.getElementType());
+    Type adjustedInputType =
+        RankedTensorType::get(isCollapse ? expandedShape : collapsedShape,
+                              resultType.getElementType());
+    Value castedInput =
+        rewriter.create<tensor::CastOp>(loc, adjustedInputType, input);
+    Value result =
+        isCollapse
+            ? rewriter
+                  .create<linalg::TensorCollapseShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result()
+            : rewriter
+                  .create<linalg::TensorExpandShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
   }
 };