Commit 732a76f

Make broadcasting result shape more static

This involves the following two parts:
- Change type refinement (RefineTypes) to propagate more static shape info.
- Get as much static shape info as possible when creating the result tensor during conversion to linalg.
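For context, here is a minimal Python sketch (not part of the commit; names are illustrative) of the broadcast-shape rule that the new fillInSizesForBinaryBroadcastingOp helper in RefineTypes.cpp implements: align ranks from the right, keep a dimension dynamic if either input dimension is dynamic, give up on shape info if two static non-1 sizes disagree, and otherwise take the larger size.

kUnknownSize = -1  # marker for a dynamic dimension, as in RefineTypes

def broadcast_result_shape(lhs, rhs):
    # Returns the broadcast result shape, or None when the static sizes are
    # incompatible (mirroring how the helper clears the shape info).
    rank = max(len(lhs), len(rhs))
    # Left-pad the shorter shape with 1s so both operands have the result rank.
    lhs = [1] * (rank - len(lhs)) + list(lhs)
    rhs = [1] * (rank - len(rhs)) + list(rhs)
    result = []
    for l, r in zip(lhs, rhs):
        if l == kUnknownSize or r == kUnknownSize:
            result.append(kUnknownSize)  # dynamic stays dynamic
        elif l != r and l != 1 and r != 1:
            return None  # incompatible broadcast; no static shape info
        else:
            result.append(max(l, r))
    return result

# The shapes from the new e2e test below broadcast to a fully static result:
# broadcast_result_shape([5, 4, 3, 3, 1], [4, 3, 1, 2]) == [5, 4, 3, 3, 2]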
6 files changed (+134 −62 lines)
e2e_testing/torchscript/elementwise.py

Lines changed: 65 additions & 41 deletions
@@ -62,6 +62,28 @@ def ElementwiseBinaryModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 4, 3, 3, 1], torch.float32, True),
+        ([4, 3, 1, 2], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return a * b
+
+@register_test_case(
+    module_factory=lambda: ElementwiseBinaryStaticShapeModule())
+def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2))
+
+
+# ==============================================================================
+
+
 class ElementwiseTernaryModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -171,8 +193,7 @@ def forward(self, a):
         return torch.unsqueeze(a, -3)
 
 
-@register_test_case(
-    module_factory=lambda: ElementwiseUnsqueezeNegDimsModule())
+@register_test_case(module_factory=lambda: ElementwiseUnsqueezeNegDimsModule())
 def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3))
 
@@ -255,7 +276,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseGeluModule())
 def ElementwiseGeluModule_basic(module, tu: TestUtils):
-    module.forward(2*tu.rand(5, 3) - 0.5)
+    module.forward(2 * tu.rand(5, 3) - 0.5)
 
 
 # ==============================================================================
@@ -359,7 +380,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
 def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-10, 15, (3,4)))
+    module.forward(torch.randint(-10, 15, (3, 4)))
 
 
 class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
@@ -377,7 +398,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
 def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32))
+    module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
 
 
 class ElementwiseGtFloatTensorModule(torch.nn.Module):
@@ -415,10 +436,12 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
 def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))
+    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
+
 
 # ==============================================================================
 
+
 class ElementwiseLtFloatScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -452,7 +475,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
 def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-10, 15, (3,4)))
+    module.forward(torch.randint(-10, 15, (3, 4)))
 
 
 class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
@@ -468,9 +491,10 @@ def forward(self, x):
         return torch.lt(x, 2)
 
 
-@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
+@register_test_case(
+    module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
 def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32))
+    module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
 
 
 class ElementwiseLtFloatTensorModule(torch.nn.Module):
@@ -508,10 +532,12 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
 def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))
+    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
+
 
 # ==============================================================================
 
+
 class ElementwiseEqFloatScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -527,8 +553,8 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
 def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
-                   .to(torch.float32))
+    module.forward(
+        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
 
 
 class ElementwiseEqIntScalarModule(torch.nn.Module):
@@ -546,7 +572,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
 def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(2, 4, (5,8)))
+    module.forward(torch.randint(2, 4, (5, 8)))
 
 
 class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
@@ -562,9 +588,10 @@ def forward(self, x):
         return torch.eq(x, 2)
 
 
-@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
+@register_test_case(
+    module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
 def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(2, 4, (5,8)).to(torch.int32))
+    module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
 
 
 class ElementwiseEqFloatTensorModule(torch.nn.Module):
@@ -583,9 +610,9 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
 def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
-    module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]])
-                   .to(torch.float32),
-                   torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
+    module.forward(
+        torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
+        torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
 
 
 class ElementwiseEqIntTensorModule(torch.nn.Module):
@@ -604,10 +631,12 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
 def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5,)))
+    module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
+
 
 # ==============================================================================
 
+
 class ElementwiseClampModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -666,7 +695,7 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: RsubModule_noalpha())
 def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
-
+
 # ==============================================================================
 
 class ElementwiseMulScalarIntModule(torch.nn.Module):
@@ -734,12 +763,10 @@ def forward(self, a, b):
         return torch.mul(a, b)
 
 
-@register_test_case(
-    module_factory=lambda: ElementwiseMulTensorFloatModule())
+@register_test_case(module_factory=lambda: ElementwiseMulTensorFloatModule())
 def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
-    module.forward(
-        tu.rand(4),
-        tu.rand(4).type(torch.float64))
+    module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
+
 
 class ElementwiseMulTensorIntModule(torch.nn.Module):
     def __init__(self):
@@ -755,12 +782,10 @@ def forward(self, a, b):
         return torch.mul(a, b)
 
 
-@register_test_case(
-    module_factory=lambda: ElementwiseMulTensorIntModule())
+@register_test_case(module_factory=lambda: ElementwiseMulTensorIntModule())
 def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
     module.forward(
-        torch.randint(10, [4]).type(torch.int32),
-        torch.randint(10, [4]))
+        torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
 
 
 # ==============================================================================
@@ -783,7 +808,7 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
 
 
 class ElementwiseSqrtModule(torch.nn.Module):
-    def __init__(self):
+    def __init__(self):
         super().__init__()
 
     @export
@@ -898,7 +923,7 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 class ElementwiseRsqrtModule(torch.nn.Module):
-    def __init__(self):
+    def __init__(self):
         super().__init__()
 
     @export
@@ -984,12 +1009,9 @@ def forward(self, a, b):
         return torch.div(a, b)
 
 
-@register_test_case(
-    module_factory=lambda: ElementwiseDivTensorFloatModule())
+@register_test_case(module_factory=lambda: ElementwiseDivTensorFloatModule())
 def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
-    module.forward(
-        tu.rand(4),
-        tu.rand(4).type(torch.float64))
+    module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
 
 
 # ==============================================================================
@@ -1005,15 +1027,15 @@ def __init__(self):
         ([-1, -1], torch.int32, True),
         ([-1, -1], torch.int64, True),
     ])
-
     def forward(self, x, y):
         return torch.bitwise_and(x, y)
 
 
 @register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
 def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
-                   torch.randint(-10, 10, (3, 4)))
+    module.forward(
+        torch.randint(-10, 10, (3, 4)).to(torch.int32),
+        torch.randint(-10, 10, (3, 4)))
 
 
 class ElementwiseSubScalarIntModule(torch.nn.Module):
@@ -1026,7 +1048,8 @@ def __init__(self):
         ([-1, -1], torch.int64, True),
     ])
     def forward(self, x):
-        return torch.sub(x, 2.1, alpha = 2)
+        return torch.sub(x, 2.1, alpha=2)
+
 
 @register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
 def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
@@ -1077,7 +1100,8 @@ def __init__(self):
         ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return torch.add(x, 3.0, alpha = 2)
+        return torch.add(x, 3.0, alpha=2)
+
 
 @register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule())
 def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
     "ElementwiseReluModule_basic",
    "ElementwiseFloorModule_basic",
     "ElementwiseLogModule_basic",
+    "ElementwiseBinaryStaticShapeModule_basic",
     "TanhBackward_basic",
     "ElementwiseAddModule_basic",
     "ReturnThreeTensorFloat32_basic",

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 4 additions & 3 deletions
@@ -2345,7 +2345,7 @@ struct ConvertElementwiseOp : ConversionPattern {
         // undefined behavior, by doing appropriate checks against the current
         // dimension size.
         auto currentDimSize =
-            rewriter.create<tensor::DimOp>(loc, tensorOperand, size.index());
+            getDimOp(rewriter, loc, tensorOperand, size.index());
 
         // If the result size of this dimension has so far only hit the
         // statically-known-to-be-1 case above (i.e., we have not yet assigned a
@@ -2372,12 +2372,13 @@ struct ConvertElementwiseOp : ConversionPattern {
           /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
     }
 
-    SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
+    SmallVector<StringRef> iteratorTypes(resultRank,
+                                         getParallelIteratorTypeName());
     // Add the indexing map for the outs init tensor.
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultShape, resultType.getElementType());
+        loc, getAsOpFoldResult(resultShape), resultType.getElementType());
     bool hadErrorCreatingPayload = false;
     auto generic = rewriter.create<linalg::GenericOp>(
         loc, /*resultTensorTypes=*/initTensor.getType(),

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 44 additions & 11 deletions
@@ -759,6 +759,47 @@ static void fillInSizesGivenSizesList(ValueKnowledge &knowledge, Value sizes) {
   }
 }
 
+static void fillInSizesForBinaryBroadcastingOp(ValueKnowledge &lhs,
+                                               ValueKnowledge &rhs,
+                                               ValueKnowledge &knowledge) {
+  if (lhs.hasSizes && rhs.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
+                           kUnknownSize);
+
+    int64_t resultRank = knowledge.sizes.size();
+    auto increaseRankToResultRank =
+        [&](const std::vector<int64_t> &sizes) -> std::vector<int64_t> {
+      int offset = resultRank - sizes.size();
+      std::vector<int64_t> newSizes(std::max(offset, 0), 1);
+      newSizes.insert(newSizes.end(), sizes.begin(), sizes.end());
+      return newSizes;
+    };
+
+    std::vector<int64_t> rankAdjustedSizesLhs =
+        increaseRankToResultRank(lhs.sizes);
+    std::vector<int64_t> rankAdjustedSizesRhs =
+        increaseRankToResultRank(rhs.sizes);
+
+    for (int64_t i = 0; i < resultRank; i++) {
+      int64_t lhsDimSize = rankAdjustedSizesLhs[i];
+      int64_t rhsDimSize = rankAdjustedSizesRhs[i];
+      // Dynamic shape can't be decided at compilation.
+      if (lhsDimSize == kUnknownSize || rhsDimSize == kUnknownSize)
+        continue;
+
+      // Incompatible broadcasting shape.
+      if (lhsDimSize != rhsDimSize && lhsDimSize != 1 && rhsDimSize != 1) {
+        knowledge.hasSizes = false;
+        knowledge.sizes.clear();
+        return;
+      }
+
+      knowledge.sizes[i] = std::max(lhsDimSize, rhsDimSize);
+    }
+  }
+}
+
 ChangeResult TypeAnalyzer::visitAtenMmOp(
     AtenMmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto &lhs = operands[0]->getValue();
@@ -950,11 +991,7 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
   auto rhs = operands[1]->getValue();
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(getContext());
-  if (lhs.hasSizes && rhs.hasSizes) {
-    knowledge.hasSizes = true;
-    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
-                           kUnknownSize);
-  }
+  fillInSizesForBinaryBroadcastingOp(lhs, rhs, knowledge);
 
   // The alpha in `aten.add.Tensor` and `aten.sub.Tensor` has to be lower type
   // category than the lhs and rhs and therefore doesn't really contribute to
@@ -969,12 +1006,8 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingComparisonOp(
   auto rhs = operands[1]->getValue();
   auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(getContext());
-  if (lhs.hasSizes && rhs.hasSizes) {
-    knowledge.hasSizes = true;
-    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
-                           kUnknownSize);
-  }
-  knowledge.dtype = IntegerType::get(op->getContext(), 1);
+  fillInSizesForBinaryBroadcastingOp(lhs, rhs, knowledge);
+  knowledge.dtype = IntegerType::get(op->getContext(), 1);
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
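Worked example (using the sketch near the top of this page, with the shapes from the new e2e test): broadcasting a [5, 4, 3, 3, 1] operand against a [4, 3, 1, 2] operand aligns the ranks from the right and gives a fully static result shape of [5, 4, 3, 3, 2], which is what lets the linalg lowering above build its init tensor from static sizes instead of tensor.dim values.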
