19 changes: 14 additions & 5 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -76,6 +76,15 @@ static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
   return transpose;
 }

+static int64_t getDimFromValue(Value dimValue) {
+  if (auto constOp = dimValue.getDefiningOp<arith::ConstantOp>()) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+      return intAttr.getInt();
+    }
+  }
+  return ShapedType::kDynamic;
+}
+
Collaborator:

I feel like this exact function likely already exists in some Utils.cpp file.

There may even be a computeBroadcastShape util function at this point, so I'd take a look and see if we can re-use some of that.

Author:

Cool! I will have a look.
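As a rough sketch of the kind of reuse being suggested: MLIR already ships generic constant matchers (matchPattern with m_ConstantInt from mlir/IR/Matchers.h), so the extraction could be written against those instead of pattern-matching arith::ConstantOp directly. The helper name below is hypothetical, and this is only an illustration of the idea, not necessarily the existing util the comment refers to (it assumes the usual using namespace mlir; in this file):

// Sketch only: same behavior as getDimFromValue above, but the generic
// matcher accepts any op that folds to an integer constant, not just
// arith::ConstantOp.
#include "mlir/IR/Matchers.h"

static int64_t getStaticDimOrDynamic(Value dimValue) {
  APInt dim;
  if (matchPattern(dimValue, m_ConstantInt(&dim)))
    return dim.getSExtValue();
  return ShapedType::kDynamic;
}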

 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -505,9 +514,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {

     // Broadcast the batch dimensions of both the matrices.
     Value broadcastedLhs, broadcastedRhs;
-    // TODO: Improve usage of static shape information.
-    SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector<int64_t> lhsTargetShape = llvm::to_vector(
+        llvm::map_range(lhsBroadcastToShape, getDimFromValue));
+
     auto lhsBroadcastType = RankedTensorType::get(
         lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
@@ -516,8 +525,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
-    SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector<int64_t> rhsTargetShape = llvm::to_vector(
+        llvm::map_range(rhsBroadcastToShape, getDimFromValue));
     auto rhsBroadcastType = RankedTensorType::get(
         rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
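For intuition, a hypothetical example of what the new mapping produces (the values and element type below are made up, not taken from this patch): if rhsBroadcastToShape holds three Values where the first two are defined by arith::ConstantOp (4 and 8) and the third is computed at runtime, then

// Hypothetical illustration of the mapping above:
//   rhsBroadcastToShape = {%c4, %c8, %dynamicDim}
SmallVector<int64_t> rhsTargetShape = llvm::to_vector(
    llvm::map_range(rhsBroadcastToShape, getDimFromValue));
// rhsTargetShape == {4, 8, ShapedType::kDynamic}

so rhsBroadcastType is built as tensor<4x8x?xf32> (for an f32 element type) rather than the fully dynamic tensor<?x?x?xf32> that the old code always produced.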