@@ -4426,88 +4426,92 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    // Materialize out 1 dimensions to broadcast along. This includes
-    // materializing out preceding batch dimensions:
-    for (int i = 0; i < repeatSz; ++i) {
-      auto oldSizes = selfTy.getSizes();
-      llvm::SmallVector<int64_t> sizes;
-      int64_t squeezeDim = i < batch ? i : i * 2 - batch;
+    // Fold the constant values so that we know which we materialize:
+    llvm::SmallVector<int64_t> repeatInts;
+    for (int i = 0, s = repeats.size(); i < s; ++i) {
+      int64_t repeat;
+      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
+        repeat = Torch::kUnknownSize;
 
-      for (int j = 0; j < squeezeDim; ++j)
-        sizes.push_back(oldSizes[j]);
-      sizes.push_back(1);
-      for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
-        sizes.push_back(oldSizes[j]);
+      repeatInts.push_back(repeat);
+    }
+
+    // Unsqueeze all newly created dims
+    llvm::SmallVector<int> unsqueezeDims;
+    for (int i = 0; i < batch; ++i) {
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(i);
+    }
 
-      Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
-      selfTy =
-          rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
-      self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
+    // Unsqueeze any non-unary repeats for existing dims
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+      int64_t dim = i + unsqueezeDims.size() - batch;
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(dim);
     }
 
+    // Materialize the expansion sizes for each dim:
     llvm::SmallVector<Value> lengths;
-    for (int i = 0; i < repeatSz; ++i) {
-      if (i < batch) {
+    llvm::SmallVector<int64_t> expandShape;
+    for (int i = 0; i < batch; ++i) {
+      lengths.push_back(repeats[i]);
+      expandShape.push_back(repeatInts[i]);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] != 1) {
         lengths.push_back(repeats[i]);
-        continue;
+        expandShape.push_back(repeatInts[i]);
       }
 
-      Value iv = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
-      Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
-      lengths.push_back(repeats[i]);
-      lengths.push_back(dim);
+      int dim = lengths.size();
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
+      lengths.push_back(dimV);
+      expandShape.push_back(selfTy.getSizes()[dim]);
     }
 
+    // Materialize the broadcast:
     Value lengthv = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(rewriter.getType<IntType>()), lengths);
+    selfTy = rewriter.getType<ValueTensorType>(expandShape,
+                                               selfTy.getOptionalDtype());
+    self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
-    llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
-    for (int i = 0; i < repeatSz; ++i) {
-      int64_t repeatDim = i < batch ? i : i * 2 - batch;
-      int64_t repeat;
-      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
-        repeat = Torch::kUnknownSize;
-      expandShape[repeatDim] = repeat;
-    }
+    auto outShape = cast<ValueTensorType>(op.getResult().getType()).getSizes();
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
 
-    auto mulDim = [](int64_t lhs, int64_t rhs) {
-      if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
-        return Torch::kUnknownSize;
-      return lhs * rhs;
-    };
+      auto selfShape = selfTy.getSizes();
+      llvm::SmallVector<int64_t> flattenShape;
+      for (int j = 0; j <= i; ++j)
+        flattenShape.push_back(outShape[j]);
 
-    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
-        expandShape, selfTy.getOptionalDtype());
-    Value expand =
-        rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
+      for (int j = i + 2, s = selfShape.size(); j < s; ++j)
+        flattenShape.push_back(selfShape[j]);
 
-    for (int i = 0; i < rank; ++i) {
-      auto oldShape = expandTy.getSizes();
-      llvm::SmallVector<int64_t> newShape;
-      int64_t flattenDim = i + batch;
-      for (int j = 0; j < flattenDim; ++j)
-        newShape.push_back(oldShape[j]);
-      newShape.push_back(
-          mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
-      for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
-        newShape.push_back(oldShape[j]);
-
-      expandTy = rewriter.getType<ValueTensorType>(newShape,
-                                                   expandTy.getOptionalDtype());
-
-      // Used to keep the return type the same on the last flatten:
-      expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
-
-      Value start = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim));
+      selfTy = rewriter.getType<ValueTensorType>(flattenShape,
+                                                 selfTy.getOptionalDtype());
+      Value start =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
       Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim + 1));
-      expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
-                                                       start, end);
+          loc, rewriter.getI64IntegerAttr(i + 1));
+
+      self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
+                                                     end);
     }
 
-    rewriter.replaceOp(op, expand);
+    rewriter.replaceOp(op, self);
     return success();
   }
 };
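
For intuition: the rewritten pattern is the standard reduction of aten.repeat to unsqueeze -> broadcast_to -> flatten. Each new leading "batch" dim is unsqueezed at the front and broadcast directly to its repeat count; each existing dim with a non-unit repeat gets a unit dim unsqueezed in front of it, that unit dim is broadcast to the repeat count, and the (repeat, dim) pair is then flattened back into a single dim. Repeats of 1 are skipped entirely, which is why the constants are folded into repeatInts up front. The following standalone sketch (plain C++, not the torch-mlir API; the example shapes and helper names are illustrative assumptions, not code from this patch) traces the shape progression for a [2, 3] tensor repeated with repeats [4, 1, 5]:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

static void print(const char *tag, const Shape &s) {
  std::cout << tag << ": [";
  for (size_t i = 0; i < s.size(); ++i)
    std::cout << (i ? ", " : "") << s[i];
  std::cout << "]\n";
}

int main() {
  Shape shape = {2, 3};      // input tensor shape
  Shape repeats = {4, 1, 5}; // may be longer than the input rank
  size_t batch = repeats.size() - shape.size(); // new leading dims

  // 1. Unsqueeze: one leading unit dim per batch repeat, plus a unit dim in
  //    front of every existing dim whose repeat is non-unit. A repeat of 1
  //    is a no-op, so no dim is materialized for it.
  Shape unsqueezed, expanded;
  for (size_t i = 0; i < batch; ++i) {
    unsqueezed.push_back(1);
    expanded.push_back(repeats[i]);
  }
  for (size_t i = batch; i < repeats.size(); ++i) {
    if (repeats[i] != 1) {
      unsqueezed.push_back(1);
      expanded.push_back(repeats[i]);
    }
    unsqueezed.push_back(shape[i - batch]);
    expanded.push_back(shape[i - batch]);
  }
  print("after unsqueeze", unsqueezed); // [1, 2, 1, 3]

  // 2. Broadcast each unit dim to its repeat count.
  print("after broadcast_to", expanded); // [4, 2, 5, 3]

  // 3. Flatten each (repeat, dim) pair back into one dim; batch dims stay.
  Shape result(expanded.begin(), expanded.begin() + batch);
  for (size_t i = batch; i < repeats.size(); ++i)
    result.push_back(repeats[i] * shape[i - batch]);
  print("result", result); // [4, 2, 15], matching aten.repeat
}

Compared to the old code, which unsqueezed and flattened every dim unconditionally, skipping unit repeats avoids materializing no-op reshapes, and taking the flattened sizes from the op's result type preserves static dims where they are known.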