Commit d8ba681

Lower aten::view with linalg.collapse and linalg.expand
We only handle the expanding OR collapsing cases; we do not handle expanding AND collapsing happening at the same time, or cases that are neither collapsing nor expanding, such as a view of [2, 3] on a 3x2 tensor. It is assumed that if a shape list element comes from `aten.size(tensor, dim)`, the corresponding dim is not split or collapsed. This assumption makes it easier to deal with dynamic shapes.
1 parent bc9abbc commit d8ba681
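For illustration only (not part of this commit), a minimal PyTorch sketch of the cases the new lowering does and does not cover:

import torch

t = torch.rand(2, 4, 384)

# Expanding only: rank grows; some input dims are split. Handled.
expanded = t.view(2, 4, 12, 32)

# Collapsing only: rank shrinks; adjacent input dims are merged. Handled.
collapsed = t.view(2, 4 * 384)

# Output dims taken from aten.size(tensor, dim) are assumed to be neither
# split nor collapsed, which simplifies dynamic-shape handling. Handled.
dynamic = t.view(t.size(0), t.size(1), 12, 32)

# Neither expanding nor collapsing, e.g. viewing a 3x2 tensor as [2, 3]:
# not handled; a collapse-then-expand through [6] is left as a TODO.
neither = torch.rand(3, 2).view(2, 3)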

3 files changed: 278 additions & 59 deletions

e2e_testing/torchscript/view.py

Lines changed: 57 additions & 1 deletion
@@ -9,7 +9,6 @@
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
-
 class ViewExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -46,3 +45,60 @@ def forward(self, a):
 def ViewDynamicExpandModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 30, 384))
 
+
+# ==============================================================================
+class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), a.size(1), 12, 32)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
+def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 384))
+
+# ==============================================================================
+class ViewCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(8)
+
+@register_test_case(module_factory=lambda: ViewCollapseModule())
+def ViewCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4))
+
+
+# ==============================================================================
+class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+
+    def forward(self, a, b, c):
+        return a.view(a.size(0), int(b), int(c), a.size(3), 384)
+
+@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
+def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))

include/torch-mlir/Dialect/Torch/IR/TorchOps.h

Lines changed: 28 additions & 0 deletions
@@ -108,6 +108,34 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
   return detail::torch_list_construct_op_binder(bind_values);
 }
 
+namespace detail {
+/// Matches the expected tensor and dim from `torch.aten.size.int`.
+struct torch_tensor_size_int_op_binder {
+  int64_t *dim;
+  Value tensor;
+
+  /// Creates a matcher instance that binds the value to dim if match succeeds.
+  torch_tensor_size_int_op_binder(Value tensor, int64_t *dim)
+      : dim(dim), tensor(tensor) {}
+
+  bool match(Operation *op) {
+    if (auto atenSizeIntOp = dyn_cast<Torch::AtenSizeIntOp>(op)) {
+      if (atenSizeIntOp.self() == tensor) {
+        if (matchPattern(atenSizeIntOp.dim(), m_TorchConstantInt(dim)))
+          return true;
+      }
+    }
+    return false;
+  }
+};
+} // namespace detail
+
+/// Matches the tensor and dim of `torch.size.int`.
+inline detail::torch_tensor_size_int_op_binder
+m_TorchTensorSizeInt(Value tensor, int64_t *dim) {
+  return detail::torch_tensor_size_int_op_binder(tensor, dim);
+}
+
 /// Create code to copy `tensor` to type `newType`.
 ///
 /// This involves two independent steps, which we keep orthogonal in our

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 193 additions & 58 deletions
@@ -133,8 +133,13 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
   return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
-static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
-  return b.create<tensor::DimOp>(loc, v, dimension);
+static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
+  if (auto tensorType = v.getType().cast<RankedTensorType>()) {
+    if (!tensorType.isDynamicDim(dim))
+      return b.create<arith::ConstantOp>(
+          loc, b.getIndexAttr(tensorType.getShape()[dim]));
+  }
+  return b.create<tensor::DimOp>(loc, v, dim);
 }
 
 static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
@@ -2671,84 +2676,214 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     Location loc = op.getLoc();
     Value input = adaptor.self();
     auto inputType = input.getType().cast<RankedTensorType>();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t inputRank = inputType.getRank();
     TypeConverter *typeConverter = getTypeConverter();
    auto resultType =
        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    // When we only have expansion of dimensions in `aten.View`, the output
-    // tensor rank will be strictly greater than the input tensor rank.
-    // TODO: Handle the cases of `aten.View` op where,
-    // 1. One or multiple dimensions are collapsed.
-    // 2. Few dimensions are expanded and few other dimensions are collapsed.
-    if (inputRank >= resultRank) {
+    // Currently, we only handle the expanding OR collapsing cases; we do not
+    // handle expanding AND collapsing happening at the same time, or cases
+    // that are neither collapsing nor expanding, like a view of [2,3] for a
+    // 3x2 tensor.
+    // TODO: For the expanding AND collapsing case, we will need to identify
+    // which dimensions are collapsing and which are expanding and do it in
+    // two steps.
+    // TODO: For neither collapsing nor expanding, we could find an
+    // intermediate shape to collapse to and then expand to the target shape,
+    // like [2,3] => [6] => [3, 2].
+    if (inputRank == resultRank)
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: operand tensor rank should be strictly less than "
-              "the desired output rank");
-    }
+          op, "unimplemented: the view op is neither expanding nor collapsing");
+
+    if (resultRank == 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "result shape of rank 0 is invalid");
+
+    // TODO: add support for the case where an input of rank 0 is expanded to
+    // size 1.
+    if (inputRank == 0)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: input rank 0 is not supported");
+
+    bool isCollapse = inputRank > resultRank ? true : false;
+    int64_t collapsedRank = isCollapse ? resultRank : inputRank;
+    int64_t expandedRank = isCollapse ? inputRank : resultRank;
 
     // Extract the desired output size as a list of integers. This list should
     // have been created using the operation `torch.prim.ListConstruct`.
-    SmallVector<Value> expectedSizeTorchInt;
-    if (!getListConstructElements(op.size(), expectedSizeTorchInt)) {
+    SmallVector<Value> outputSizeTorchInt;
+    if (!getListConstructElements(op.size(), outputSizeTorchInt)) {
       return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: the desired size is "
+                                         "unimplemented: the target size is "
                                          "not constructed from ListConstruct");
     }
-    SmallVector<Value> expectedSize = getTypeConvertedValues(
-        rewriter, loc, typeConverter, expectedSizeTorchInt);
-    if (resultRank != (int64_t)expectedSize.size()) {
+    SmallVector<Value> outputSizeInt = getTypeConvertedValues(
+        rewriter, loc, typeConverter, outputSizeTorchInt);
+    if (resultRank != (int64_t)outputSizeInt.size()) {
       return rewriter.notifyMatchFailure(
           op, "desired size list length mismatches with the result type rank");
     }
+    SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> expandedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
+    ArrayRef<Value> collapsedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);
 
-    // Check if the `aten.View` can be legalized to `linalg.TensorExpandShape`.
-    // It only handles the case of static dimension expansion. If the dimension
-    // is dynamic, it must not be expanded/splitted.
-    // TODO: Handle the case of dynamic dimension expansion.
-    SmallVector<ReassociationIndices> reassociation(inputRank);
-    SmallVector<int64_t> resultShape;
-    int64_t j = 0;
-    for (auto i : llvm::seq<int64_t>(0, inputRank)) {
-      if (inputType.isDynamicDim(i)) {
-        Value dim = getDimOp(rewriter, loc, input, i);
-        if (j >= resultRank) {
-          return rewriter.notifyMatchFailure(
-              op, "desired size is not compatible with the input tensor size");
+    // Iterate through the view op size list to do the following:
+    //
+    // 1. Combine the output size list and the input tensor type info to get
+    // the most static outputShape.
+    //
+    // 2. Fill in the reassociation for size list items whose output dim size
+    // comes from `torch.aten.size.int(inputTensor, inputDim)`. We naively
+    // assume this means the corresponding dimension is not expanded or
+    // collapsed. Note this may technically not always be true.
+    // TODO: think of a better way to at least detect when this assumption is
+    // violated.
+    SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
+    SmallVector<ReassociationIndices> reassociation(collapsedRank);
+    for (auto en : llvm::enumerate(outputSizeTorchInt)) {
+      int64_t inputDim;
+      int64_t outputDim = en.index();
+      // Match torch.aten.size.int(inputTensor, inputDim) with a constant
+      // inputDim.
+      if (matchPattern(en.value(),
+                       m_TorchTensorSizeInt(op.self(), &inputDim))) {
+        auto collapsedDim = isCollapse ? outputDim : inputDim;
+        auto expandedDim = isCollapse ? inputDim : outputDim;
+        reassociation[collapsedDim].push_back(expandedDim);
+        if (!inputType.isDynamicDim(inputDim)) {
+          outputShape[outputDim] = inputShape[inputDim];
+          continue;
         }
-        checkDimEqualHelper(rewriter, loc, dim, expectedSize[j]);
-        reassociation[i].push_back(j++);
-        resultShape.push_back(kUnknownSize);
-      } else {
-        int64_t expandedDim = inputType.getDimSize(i);
-        int64_t outputDim;
-        // A do-while loop is used here to handle the cases where the input
-        // tensor has a dimension of size 1.
-        do {
-          if (j >= resultRank ||
-              !matchPattern(expectedSizeTorchInt[j],
-                            m_TorchConstantInt(&outputDim)) ||
-              expandedDim % outputDim != 0) {
+      }
+
+      int64_t size;
+      if (matchPattern(en.value(), m_TorchConstantInt(&size)))
+        outputShape[outputDim] = size;
+    }
+
+    SmallVector<int64_t> collapsedShape =
+        isCollapse ? outputShape : llvm::to_vector(inputShape);
+    SmallVector<int64_t> expandedShape =
+        isCollapse ? llvm::to_vector(inputShape) : outputShape;
+
+    // The while loop does the following:
+    // 1. Fill in the reassociation indices for dimensions that are expanded.
+    // Check the interval of dimensions between two unchanged dims in the
+    // collapsedShape. If the interval is of size 1, associate all the dims in
+    // the expandedShape until the next unchanged dim. If the interval is
+    // larger than size 1, figure out the associations under the assumption
+    // that dynamic dimensions are not split.
+    // 2. Set collapsedShape and expandedShape following the requirements of
+    // the tensor.expand_shape verification code:
+    //    a. As long as one or more of the related dimensions in the expanded
+    //    shape is dynamic, the collapsed dimension is dynamic.
+    //    b. If all of the related dimensions are static, the collapsed
+    //    dimension must be static. In other words, if a collapsed dimension
+    //    is dynamic, at least one of the related dimensions needs to be
+    //    dynamic.
+    int64_t collapsedDim = 0, expandedDim = 0;
+    while (collapsedDim < collapsedRank && expandedDim < expandedRank) {
+      // Not empty means the associations have been filled in and the
+      // dimension is unchanged.
+      if (!reassociation[collapsedDim].empty()) {
+        if (expandedDim != reassociation[collapsedDim][0])
+          return op.emitOpError("Unsupported: expanded dims are off from the "
+                                "expected dim got from reassociation");
+        collapsedDim++;
+        expandedDim++;
+        continue;
+      }
+
+      // Collect the dims that are collapsed until hitting the next dim that
+      // is unchanged.
+      SmallVector<int64_t> collapsedDims;
+      while (collapsedDim < collapsedRank &&
+             reassociation[collapsedDim].empty()) {
+        collapsedDims.push_back(collapsedDim);
+        collapsedDim++;
+      }
+      // The next reassociation is for a dim that is unchanged.
+      int64_t expandedDimNext = collapsedDim != collapsedRank
+                                    ? reassociation[collapsedDim][0]
+                                    : expandedRank;
+      if (collapsedDims.size() == 1) {
+        int64_t collapsedDimSize = 1;
+        int64_t collapsedDim = collapsedDims[0];
+        for (auto i : llvm::seq<int64_t>(expandedDim, expandedDimNext)) {
+          reassociation[collapsedDim].push_back(i);
+          if (collapsedDimSize == kUnknownSize)
+            continue;
+
+          int64_t expandedDimSize = expandedShape[i];
+          if (expandedDimSize == kUnknownSize) {
+            collapsedDimSize = kUnknownSize;
+            continue;
+          }
+          collapsedDimSize *= expandedShape[i];
+        }
+        // To meet both requirements of the tensor.expand_shape verification
+        // code.
+        collapsedShape[collapsedDim] = collapsedDimSize;
+        expandedDim = expandedDimNext;
+        continue;
+      }
+
+      // collapsedDims are expanded to [expandedDim, expandedDimNext).
+      if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size())
+        op.emitError("unimplemented: mix of expanding and collapsing "
+                     "operations for view");
+      for (auto collapsedDim : collapsedDims) {
+        if (collapsedShape[collapsedDim] == kUnknownSize) {
+          if (expandedDim >= expandedDimNext) {
            return rewriter.notifyMatchFailure(
-                op, "total number of elements mismatch in the expansion");
+                op,
+                "desired size is not compatible with the input tensor size");
           }
-          reassociation[i].push_back(j++);
-          resultShape.push_back(outputDim);
-          expandedDim /= outputDim;
-        } while (expandedDim != 1);
+          checkDimEqualHelper(rewriter, loc,
+                              collapsedShapeTorchInt[collapsedDim],
+                              expandedShapeTorchInt[expandedDim]);
+          // To meet the second requirement of the tensor.expand_shape
+          // verification code.
+          expandedShape[expandedDim] = kUnknownSize;
+          reassociation[collapsedDim].push_back(expandedDim++);
+        } else {
+          int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
+          // A do-while loop is used here to handle the cases where the
+          // collapsed shape tensor has a dimension of size 1.
+          do {
+            int64_t expandedDimSize = expandedShape[expandedDim];
+            if (expandedDim >= expandedDimNext ||
+                expandedShape[expandedDim] == kUnknownSize ||
+                remainingSizeToExpand % expandedDimSize != 0) {
+              return rewriter.notifyMatchFailure(
+                  op, "total number of elements mismatch in the expansion");
+            }
+            reassociation[collapsedDim].push_back(expandedDim++);
+            remainingSizeToExpand /= expandedDimSize;
+          } while (remainingSizeToExpand != 1);
+        }
       }
     }
-    // Make sure that the splitted dimensions have the same number of elements
-    // as the dimension got splitted from.
-    if (j != resultRank)
-      return rewriter.notifyMatchFailure(
-          op, "desired size is not compatible with the input tensor size");
 
-    Type expandType =
-        RankedTensorType::get(resultShape, resultType.getElementType());
-    Value expandOp = rewriter.create<linalg::TensorExpandShapeOp>(
-        loc, expandType, adaptor.self(), reassociation);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expandOp);
+    if (collapsedDim != collapsedRank || expandedDim != expandedRank)
+      return rewriter.notifyMatchFailure(op, "view shape is not supported");
+    Type adjustedResultType =
+        RankedTensorType::get(isCollapse ? collapsedShape : expandedShape,
+                              resultType.getElementType());
+    Type adjustedInputType =
+        RankedTensorType::get(isCollapse ? expandedShape : collapsedShape,
+                              resultType.getElementType());
+    Value castedInput =
+        rewriter.create<tensor::CastOp>(loc, adjustedInputType, input);
+    Value result =
+        isCollapse
+            ? rewriter
+                  .create<linalg::TensorCollapseShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result()
+            : rewriter
+                  .create<linalg::TensorExpandShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
   }
 };

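As an aside (not part of the commit), a small Python sketch of the reassociation bookkeeping described in the comments above: each reassociation group lists the expanded dims that fold into one collapsed dim, and a collapsed dim is static only when every dim in its group is static, mirroring the tensor.expand_shape rules cited in the new code. The helper name and the -1 sentinel below are illustrative only.

KUNKNOWN = -1  # illustrative stand-in for kUnknownSize (a dynamic dimension)

def collapse_shape(expanded_shape, reassociation):
    # Compute the collapsed shape implied by the reassociation groups.
    collapsed = []
    for group in reassociation:
        size = 1
        for d in group:
            if expanded_shape[d] == KUNKNOWN:
                size = KUNKNOWN
                break
            size *= expanded_shape[d]
        collapsed.append(size)
    return collapsed

# Loosely mirrors the ViewCollapseDynamicWithAtenSizeIntModule test: a
# [2, ?, ?, 4, 12, 32] tensor viewed as [2, ?, ?, 4, 384]; only the last two
# dims are collapsed, the rest are unchanged.
print(collapse_shape([2, KUNKNOWN, KUNKNOWN, 4, 12, 32],
                     [[0], [1], [2], [3], [4, 5]]))  # [2, -1, -1, 4, 384]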