Skip to content

Commit 0d0653a

Browse files
authored
[torch] Rework torch.repeat to not broadcast unary case (#4061)
Not all dimensions in `torch.repeat` need to be broadcast. Skip unsqueezing and flattening those dimensions, and fold them together instead.
1 parent 91a6c15 commit 0d0653a

File tree

3 files changed

+97
-68
lines changed

3 files changed

+97
-68
lines changed

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,37 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
457457
return success();
458458
}
459459

460+
// Flatten size-1 broadcast dims to simplify the final generic op.
461+
// If all dims are size-1 broadcast dims, then this will collapse to a
462+
// rank-0 tensor.
463+
SmallVector<ReassociationIndices> collapseExprs;
464+
for (int64_t i = 0, e = inputRank; i < e; ++i) {
465+
if (!broadcastedStatus[i]) {
466+
collapseExprs.push_back({});
467+
}
468+
}
469+
470+
int64_t previous = -1;
471+
bool collapse = false;
460472
SmallVector<AffineExpr> inputExprs;
461473
for (int64_t i = 0, e = inputRank; i < e; ++i) {
462-
if (broadcastedStatus[i]) {
463-
inputExprs.push_back(rewriter.getAffineConstantExpr(0));
474+
if (!broadcastedStatus[i]) {
475+
previous++;
476+
collapseExprs[previous].push_back(i);
477+
inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
464478
continue;
465479
}
466-
inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
480+
481+
int64_t clamped = previous < 0 ? 0 : previous;
482+
if (!collapseExprs.empty()) {
483+
collapseExprs[clamped].push_back(i);
484+
}
485+
collapse = true;
486+
}
487+
488+
if (collapse) {
489+
input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
490+
collapseExprs);
467491
}
468492

469493
SmallVector<AffineMap> indexingMaps = {

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 67 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4426,88 +4426,92 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
44264426
if (!selfTy.hasSizes())
44274427
return rewriter.notifyMatchFailure(op, "input sizes unknown");
44284428

4429-
// Materialize out 1 dimensions to broadcast along. This includes
4430-
// materializing out preceding batch dimensions:
4431-
for (int i = 0; i < repeatSz; ++i) {
4432-
auto oldSizes = selfTy.getSizes();
4433-
llvm::SmallVector<int64_t> sizes;
4434-
int64_t squeezeDim = i < batch ? i : i * 2 - batch;
4429+
// Fold the constant values so that we know which we materialize:
4430+
llvm::SmallVector<int64_t> repeatInts;
4431+
for (int i = 0, s = repeats.size(); i < s; ++i) {
4432+
int64_t repeat;
4433+
if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
4434+
repeat = Torch::kUnknownSize;
44354435

4436-
for (int j = 0; j < squeezeDim; ++j)
4437-
sizes.push_back(oldSizes[j]);
4438-
sizes.push_back(1);
4439-
for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
4440-
sizes.push_back(oldSizes[j]);
4436+
repeatInts.push_back(repeat);
4437+
}
4438+
4439+
// Unsqueeze all newly created dims
4440+
llvm::SmallVector<int> unsqueezeDims;
4441+
for (int i = 0; i < batch; ++i) {
4442+
Value iv =
4443+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
4444+
self = *unsqueezeTensor(rewriter, op, self, iv);
4445+
selfTy = cast<ValueTensorType>(self.getType());
4446+
unsqueezeDims.push_back(i);
4447+
}
44414448

4442-
Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
4443-
selfTy =
4444-
rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
4445-
self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
4449+
// Unsqueeze any non-unary repeats for existing dims
4450+
for (int i = batch, s = repeats.size(); i < s; ++i) {
4451+
if (repeatInts[i] == 1)
4452+
continue;
4453+
int64_t dim = i + unsqueezeDims.size() - batch;
4454+
Value iv =
4455+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
4456+
self = *unsqueezeTensor(rewriter, op, self, iv);
4457+
selfTy = cast<ValueTensorType>(self.getType());
4458+
unsqueezeDims.push_back(dim);
44464459
}
44474460

4461+
// Materialize the expansion sizes for each dim:
44484462
llvm::SmallVector<Value> lengths;
4449-
for (int i = 0; i < repeatSz; ++i) {
4450-
if (i < batch) {
4463+
llvm::SmallVector<int64_t> expandShape;
4464+
for (int i = 0; i < batch; ++i) {
4465+
lengths.push_back(repeats[i]);
4466+
expandShape.push_back(repeatInts[i]);
4467+
}
4468+
4469+
for (int i = batch, s = repeats.size(); i < s; ++i) {
4470+
if (repeatInts[i] != 1) {
44514471
lengths.push_back(repeats[i]);
4452-
continue;
4472+
expandShape.push_back(repeatInts[i]);
44534473
}
44544474

4455-
Value iv = rewriter.create<ConstantIntOp>(
4456-
loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
4457-
Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
4458-
lengths.push_back(repeats[i]);
4459-
lengths.push_back(dim);
4475+
int dim = lengths.size();
4476+
Value iv =
4477+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
4478+
Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
4479+
lengths.push_back(dimV);
4480+
expandShape.push_back(selfTy.getSizes()[dim]);
44604481
}
44614482

4483+
// Materialize the broadcast:
44624484
Value lengthv = rewriter.create<PrimListConstructOp>(
44634485
loc, ListType::get(rewriter.getType<IntType>()), lengths);
4486+
selfTy = rewriter.getType<ValueTensorType>(expandShape,
4487+
selfTy.getOptionalDtype());
4488+
self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
44644489

4465-
llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
4466-
for (int i = 0; i < repeatSz; ++i) {
4467-
int64_t repeatDim = i < batch ? i : i * 2 - batch;
4468-
int64_t repeat;
4469-
if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
4470-
repeat = Torch::kUnknownSize;
4471-
expandShape[repeatDim] = repeat;
4472-
}
4490+
auto outShape = cast<ValueTensorType>(op.getResult().getType()).getSizes();
4491+
for (int i = batch, s = repeats.size(); i < s; ++i) {
4492+
if (repeatInts[i] == 1)
4493+
continue;
44734494

4474-
auto mulDim = [](int64_t lhs, int64_t rhs) {
4475-
if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
4476-
return Torch::kUnknownSize;
4477-
return lhs * rhs;
4478-
};
4495+
auto selfShape = selfTy.getSizes();
4496+
llvm::SmallVector<int64_t> flattenShape;
4497+
for (int j = 0; j <= i; ++j)
4498+
flattenShape.push_back(outShape[j]);
44794499

4480-
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
4481-
expandShape, selfTy.getOptionalDtype());
4482-
Value expand =
4483-
rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
4500+
for (int j = i + 2, s = selfShape.size(); j < s; ++j)
4501+
flattenShape.push_back(selfShape[j]);
44844502

4485-
for (int i = 0; i < rank; ++i) {
4486-
auto oldShape = expandTy.getSizes();
4487-
llvm::SmallVector<int64_t> newShape;
4488-
int64_t flattenDim = i + batch;
4489-
for (int j = 0; j < flattenDim; ++j)
4490-
newShape.push_back(oldShape[j]);
4491-
newShape.push_back(
4492-
mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
4493-
for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
4494-
newShape.push_back(oldShape[j]);
4495-
4496-
expandTy = rewriter.getType<ValueTensorType>(newShape,
4497-
expandTy.getOptionalDtype());
4498-
4499-
// Used to keep the return type the same on the last flatten:
4500-
expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
4501-
4502-
Value start = rewriter.create<ConstantIntOp>(
4503-
loc, rewriter.getI64IntegerAttr(flattenDim));
4503+
selfTy = rewriter.getType<ValueTensorType>(flattenShape,
4504+
selfTy.getOptionalDtype());
4505+
Value start =
4506+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
45044507
Value end = rewriter.create<ConstantIntOp>(
4505-
loc, rewriter.getI64IntegerAttr(flattenDim + 1));
4506-
expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
4507-
start, end);
4508+
loc, rewriter.getI64IntegerAttr(i + 1));
4509+
4510+
self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
4511+
end);
45084512
}
45094513

4510-
rewriter.replaceOp(op, expand);
4514+
rewriter.replaceOp(op, self);
45114515
return success();
45124516
}
45134517
};

test/Conversion/TorchToLinalg/broadcast.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32
2222

2323
// CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
2424
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
25+
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
2526
// CHECK: %[[GENERIC:.*]] = linalg.generic
26-
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
27+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
2728
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
28-
// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
29+
// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
2930
// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
3031
// CHECK: linalg.yield %[[IN]] : f32
3132
// CHECK: } -> tensor<1x4x2xf32>

0 commit comments

Comments
 (0)