Support decomposition of torch.broadcast_tensors #4253
```diff
@@ -479,78 +479,117 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   return unsqueezed;
 }
 
-// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
+// Checks whether the inputs are broadcast compatible or not. If
 // yes, then computes the final broadcast shape.
 void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
-                                  Value inputA, Value inputB,
+                                  SmallVector<Value> inputs,
                                   SmallVector<int64_t> &resultShape,
                                   SmallVector<Value> &resultShapeValue) {
-  SmallVector<int64_t> shapeA{
-      cast<BaseTensorType>(inputA.getType()).getSizes()};
-  SmallVector<int64_t> shapeB{
-      cast<BaseTensorType>(inputB.getType()).getSizes()};
-  unsigned rankA = shapeA.size();
-  unsigned rankB = shapeB.size();
-  unsigned minRank = rankA > rankB ? rankB : rankA;
+  SmallVector<SmallVector<int64_t>> shapes;
+  SmallVector<unsigned> ranks;
+
+  for (auto input : inputs) {
+    SmallVector<int64_t> shape{
+        cast<BaseTensorType>(input.getType()).getSizes()};
+    shapes.push_back(shape);
+    ranks.push_back(shape.size());
+  }
+
+  unsigned maxRank = *std::max_element(ranks.begin(), ranks.end());
 
   // Check whether the shapes of the tensors are broadcastable or not.
   // Two tensors are “broadcastable” if the following rules hold:
   // 1.) Each tensor has at least one dimension.
   // 2.) When iterating over the dimension sizes, starting at the trailing
   // dimension, the dimension sizes must either be equal, one of them is 1, or
   // one of them does not exist.
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
+  for (unsigned i = 0; i < maxRank; i++) {
+    SmallVector<Value> sizeInputs;
+    for (auto [idx, input] : llvm::enumerate(inputs)) {
+      int sizeDimIdx = ranks[idx] - i - 1;
+      if (sizeDimIdx >= 0) {
+        auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(sizeDimIdx));
+        sizeInputs.push_back(
+            rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
+      }
+    }
     Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
-    Value cmpSizeAEqualsSizeB =
-        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, sizeInputB);
-    Value cmpSizeAEqualsOne =
-        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, torchCstOne);
-    Value cmpSizeBEqualsOne =
-        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputB, torchCstOne);
+    SmallVector<Value> predicates;
+    for (auto sizeVal : sizeInputs) {
+      Value cmpSizeEquals =
+          rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, sizeInputs.front());
+      predicates.push_back(cmpSizeEquals);
+      Value cmpSizeEqualsOne =
+          rewriter.create<Torch::AtenEqIntOp>(loc, sizeVal, torchCstOne);
+      predicates.push_back(cmpSizeEqualsOne);
+    }
     Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
-        SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
-                           cmpSizeBEqualsOne});
+        loc, Torch::ListType::get(predicates.front().getType()), predicates);
     Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
     rewriter.create<Torch::RuntimeAssertOp>(
         loc, cmp, "tensors are not broadcast compatible");
   }
```
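The runtime asserts built here encode PyTorch's broadcastability rule from the comment above. For reference, the same rule can be sketched on static shapes in plain C++; the helper below is hypothetical (not part of this patch) and illustrates only the semantics, not the IR construction:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical static-shape version of the broadcastability rule:
// walk all shapes from the trailing dimension; at each position, the
// sizes that exist must either all agree or be 1.
bool areBroadcastCompatible(const std::vector<std::vector<int64_t>> &shapes) {
  size_t maxRank = 0;
  for (const auto &s : shapes) {
    if (s.empty())
      return false; // rule 1: each tensor has at least one dimension
    maxRank = std::max(maxRank, s.size());
  }
  for (size_t i = 0; i < maxRank; ++i) {
    int64_t dim = 1;
    for (const auto &s : shapes) {
      if (i >= s.size())
        continue; // dimension does not exist: treated as broadcastable
      int64_t size = s[s.size() - i - 1];
      if (size == 1)
        continue; // size 1 broadcasts against anything
      if (dim != 1 && size != dim)
        return false; // two distinct non-1 sizes: incompatible
      dim = size;
    }
  }
  return true;
}
```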
```diff
 
   // If we reach here then it means both the shapes are broadcast compatible.
-  resultShape = rankA >= rankB ? shapeA : shapeB;
-  Value shapeTensor = rankA >= rankB ? inputA : inputB;
+  auto maxRankIdx =
+      std::max_element(ranks.begin(), ranks.end()) - ranks.begin();
+  resultShape = shapes[maxRankIdx];
+  Value shapeTensor = inputs[maxRankIdx];
+
   for (unsigned i = 0; i < resultShape.size(); i++) {
     Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(i));
     resultShapeValue.push_back(
         rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
   }
 
   unsigned resultRank = resultShape.size();
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
-    resultShapeValue[resultRank - i - 1] =
-        rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
-    if (shapeA[rankA - i - 1] == kUnknownSize ||
-        shapeB[rankB - i - 1] == kUnknownSize) {
+  for (unsigned i = 0; i < maxRank; i++) {
+    SmallVector<Value> sizeInputs;
+    for (auto [idx, input] : llvm::enumerate(inputs)) {
+      int sizeDimIdx = ranks[idx] - i - 1;
+      if (sizeDimIdx >= 0) {
+        auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(sizeDimIdx));
+        sizeInputs.push_back(
+            rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
+      }
+    }
+
+    // Compute shape value of broadcast result,
+    // which is the maximum of dimension sizes across all inputs
+    Value maxShapeVal = sizeInputs.front();
+    for (auto sizeInput : sizeInputs) {
+      maxShapeVal = rewriter.create<PrimMaxIntOp>(loc, maxShapeVal, sizeInput);
+    }
+    resultShapeValue[resultRank - i - 1] = maxShapeVal;
+
+    // Compute result shape if all input shapes are known
+    bool unknownSize = false;
+    for (auto [idx, shape] : llvm::enumerate(shapes)) {
+      if (ranks[idx] - i - 1 < shape.size() &&
+          shape[ranks[idx] - i - 1] == kUnknownSize) {
```
A review thread is attached to the bounds check on the last line above:

**zjgarvey:** This is a bit confusing.

**vinitdeodhar:** @zjgarvey, this check ensures that out-of-bounds access to `shape` is not performed in case input tensors have different ranks. The e2e test `BroadcastTensorsModuleList_multiple_ranks` covers this case. Let me know if you have a suggestion on making this check easier to read.

**zjgarvey:** Maybe "confusing" isn't the correct word. I'm rather certain the check is wrong. E.g. for [...]. Whereas, [...].
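The reviewer's concrete examples did not survive extraction, but the subtlety at issue is visible in isolation: `ranks` holds `unsigned` values, so `ranks[idx] - i - 1` wraps around instead of going negative, and the `< shape.size()` comparison only works by way of that wraparound. A minimal sketch with hypothetical values (plain C++, unrelated to the patch's IR construction):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {4, 5}; // a rank-2 input
  unsigned rank = shape.size();        // `ranks` entries are unsigned
  unsigned i = 2;                      // offset driven by a higher-rank input

  // rank - i - 1 would be -1 as a signed value, but unsigned arithmetic
  // wraps to 4294967295, so the `< shape.size()` test happens to fail and
  // the out-of-bounds read is skipped -- by wraparound, not by an
  // explicit `>= 0` check.
  unsigned wrapped = rank - i - 1;
  std::printf("wrapped index: %u\n", wrapped);                      // 4294967295
  std::printf("passes bound check: %d\n", wrapped < shape.size()); // 0
  return 0;
}
```

An explicitly signed index, as in the `int sizeDimIdx = ranks[idx] - i - 1; if (sizeDimIdx >= 0)` pattern used earlier in the same function, would make the intent unambiguous.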
```diff
+        unknownSize = true;
+      }
+    }
+
+    if (unknownSize) {
       resultShape[resultRank - i - 1] = kUnknownSize;
     } else {
-      resultShape[resultRank - i - 1] =
-          std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
+      int64_t maxShape = 1;
+      for (auto [idx, shape] : llvm::enumerate(shapes)) {
+        if (ranks[idx] - i - 1 < shape.size()) {
+          maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]);
+        }
+      }
+      resultShape[resultRank - i - 1] = maxShape;
     }
   }
 }
```
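The static result-shape computation in this second loop can likewise be illustrated standalone. The sketch below is hypothetical (names and structure are not from the patch) and omits the `kUnknownSize` handling; each output dimension is the maximum of the input sizes that exist at that right-aligned position:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical standalone mirror of the result-shape computation,
// assuming the inputs were already checked for broadcast compatibility.
std::vector<int64_t>
broadcastShape(const std::vector<std::vector<int64_t>> &shapes) {
  size_t maxRank = 0;
  for (const auto &s : shapes)
    maxRank = std::max(maxRank, s.size());
  std::vector<int64_t> result(maxRank, 1);
  for (size_t i = 0; i < maxRank; ++i)
    for (const auto &s : shapes)
      if (i < s.size())
        result[maxRank - i - 1] =
            std::max(result[maxRank - i - 1], s[s.size() - i - 1]);
  return result;
}

// e.g. broadcastShape({{3, 1, 5}, {4, 5}, {1}}) yields {3, 4, 5},
// matching torch.broadcast_tensors semantics for those static shapes.
```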