64 changes: 47 additions & 17 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,6 +52,12 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

// Forward declaration of Conv1DGenerator and its validator
namespace {
struct Conv1DGenerator;
bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
} // namespace

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1990,8 +1996,21 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
// Create a dummy rewriter first, a rewriter is not required for
// validation
IRRewriter dummyBuilder(linalgOp.getContext());
// Check if we can successfully construct a 1d convolution generator.
// For example, if it is 2d+ convolution, return failure because we don't
// support it. To use this pass on a 2d+ convolution, it should have already
// been decomposed to 1d convolution via
// DecomposeConvolutionToLowerDimOpsPass.
Contributor:

I couldn't find such a pass in-tree.

Member Author (@jerryyin), Mar 7, 2025:

Apologies, this is an IREE implementation detail and the referenced pass is also from IREE. I shouldn't reference such a pass upstream. Will remove.

I am proposing to re-use the high-level logic that's already available.

Let me use this thread to discuss it. This diff code block is the high-level pre-condition check around the linalg op. Please let me know if you are referring to a different location.

> From what I can tell, it would be totally fine to check the dims of conv ops very early on.

Could you elaborate? Are you referring to explicitly invoking inferConvolutionDims()? Then, to make sure this is a regular 2d convolution, I'd check for a combination of:

  • outputImage.size() == 2
  • batch.size() == 1
  • outputChannel.size() == 1

Reject if all of those are satisfied. I have no problem implementing this, but just want to make sure we are on the same page; a minimal sketch of such a check is below.
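For concreteness, a minimal sketch of that check (hypothetical placement inside the precondition hook, assuming inferConvolutionDims() is invoked directly on the op):

// Sketch only: reject regular 2d convolutions early.
FailureOr<linalg::ConvolutionDimensions> dims =
    linalg::inferConvolutionDims(linalgOp);
if (failed(dims))
  return failure();
// Two output-image dims plus single batch and output-channel dims
// characterize a regular 2d convolution.
if (dims->outputImage.size() == 2 && dims->batch.size() == 1 &&
    dims->outputChannel.size() == 1)
  return failure();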


Taking a step back, I don't have a lot of context about the history of the vectorization code around convolution. Since this PR does not intend a massive rewrite, I'm attempting to stay coherent with the existing code as much as possible.

One thing I've noticed, and @hanhanW rightly pointed out, is that we can fail to build a Conv1DGenerator and still allow a function (like vectorizeConvolution(), which constructs and uses the Conv1DGenerator) to invoke its member vectorization functions, which I find quite confusing. (If I were to implement this from scratch, I'd probably use singleton + initialize instead of the current approach (constructor + valid member variable). That way, a developer is required to invoke the initialize method and check the validity of the class before invoking anything on it.)

With this context, I find the most defensive approach is the one used in this PR right now:

  • With future implementations to be added and more flavors of convolution supported, it is very likely that the precondition check for vectorizing convolutions will grow out of sync (and this PR is a perfect example)
  • Now, instead of maintaining a separate function that does a subset of the constructor logic, why not re-use it and ensure we do the validity check? This looks reasonable, as the constructor is (if not cheaper, at least) no more expensive than having to infer the convolution dimensions.

With the above reasoning added up, it just looks to me like a better solution compared with inferring the convolution dimensions and rejecting a few corner cases (which can easily grow out of sync later).

Contributor:

> Let me use this thread to discuss it. This diff code block is the high-level pre-condition check around the linalg op. Please let me know if you are referring to a different location.

Similarly to Diego, I am suggesting to move the high level logic to vectorizeOpPrecondition. Also, like other "pre-condition" hooks, it should not require a rewriter.

> Could you elaborate? Are you referring to explicitly invoking inferConvolutionDims()? Then, to make sure this is a regular 2d convolution, I'd check for a combination of:

>   • outputImage.size() == 2
>   • batch.size() == 1
>   • outputChannel.size() == 1

> Reject if all of those are satisfied. I have no problem implementing this, but just want to make sure we are on the same page.

In my naivety, I was hoping that checking e.g. the rank of the filter or the input would be sufficient. But clearly not: the input for a non-channeled conv would be 1D, but for a channeled one it would be 2D. So on and so forth. IMO, you can just create something like this:

if (!isa<conv_type1, conv_type2, ...>(conv))
   return failure();

This will be a bit verbose, but there are just too many convs and whatever we try will be ... verbose 🤷🏻

> Taking a step back, I don't have a lot of context about the history of the vectorization code around convolution. Since this PR does not intend a massive rewrite, I'm attempting to stay coherent with the existing code as much as possible.

+1 to being coherent, thanks!

I was actually going to ask - do you have any plans regarding this code beyond this PR?

> One thing I've noticed, and @hanhanW rightly pointed out, is that we can fail to build a Conv1DGenerator and still allow a function (like vectorizeConvolution(), which constructs and uses the Conv1DGenerator) to invoke its member vectorization functions, which I find quite confusing. (If I were to implement this from scratch, I'd probably use singleton + initialize instead of the current approach (constructor + valid member variable). That way, a developer is required to invoke the initialize method and check the validity of the class before invoking anything on it.)

You should be able to simply add:

assert(isValid() && "Conv1DGenerator failed");

From what I can tell, that wouldn't break any tests and would make "validity" a strong prerequisite.
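For illustration, that guard would sit at the top of each vectorization entry point, e.g. (hypothetical placement):

// Sketch only: make "validity" a hard prerequisite of every entry point.
FailureOr<Operation *> generateNonChanneledConv() {
  assert(isValid() && "Conv1DGenerator failed");
  // ... existing vectorization logic ...
}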

> With this context, I find the most defensive approach is the one used in this PR right now:

>   • With future implementations to be added and more flavors of convolution supported, it is very likely that the precondition check for vectorizing convolutions will grow out of sync (and this PR is a perfect example)

There have been no new implementations in over 2 years. From what I can tell, we can safely assume that this will remain the case for the foreseeable future. So, I wouldn't worry about this.

>   • Now, instead of maintaining a separate function that does a subset of the constructor logic, why not re-use it and ensure we do the validity check? This looks reasonable, as the constructor is (if not cheaper, at least) no more expensive than having to infer the convolution dimensions.

That sounds good in theory, but in practice it means that we need an IR rewriter for the validation. "Validation"/"pre-conditioning" should not require a rewriter.

> With the above reasoning added up, it just looks to me like a better solution compared with inferring the convolution dimensions and rejecting a few corner cases (which can easily grow out of sync later).

How about my suggestion with isa?

Member Author (@jerryyin), Mar 10, 2025:

I really appreciate your thorough review comments, which give me a ton of useful information.

> I was actually going to ask - do you have any plans regarding this code beyond this PR?

Thanks for asking! I don't have any further plans... This was only meant to unblock myself from an unrelated crash that was failing downstream tests.

> That sounds good in theory, but in practice it means that we need an IR rewriter for the validation. "Validation"/"pre-conditioning" should not require a rewriter.

Agreed; I don't like having a redundant dummy rewriter just for the validation either. In fact, I took a second look at all the places where a Conv1DGenerator member function is invoked and found that all of them have access to a rewriter. The need for the rewriter really only comes from the base class StructuredGenerator constructor. I was also surprised to find that StructuredGenerator doesn't use the rewriter, yet unnecessarily stores it as state on the class. A slightly more aggressive fix would be to remove the field from the base class and pass the rewriter in on a case-by-case basis; then we'd have a clean way to construct the generator without requiring a rewriter. Sounds like a rabbit hole that I'd rather avoid in this PR :-p

> How about my suggestion with isa?

I'll adopt this; it is a cheap enough check that seems reasonable for a pre-condition. Although I'll refrain from being "complete" in this check, because in reality the linalg.conv_2d_* and linalg.conv_3d_* ops form a really long list once combined with the quantized, grouped, and non-channeled variants. I'm going to leave those other variants out and check only for the simple conv2d and conv3d cases; a sketch is below.
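For example, something along these lines (sketch only; the exact op list is to be decided):

// Sketch only: bail out early on the plain named 2d/3d convolutions;
// quantized, grouped, and non-channeled variants are deliberately left out.
if (isa<linalg::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
        linalg::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp>(linalgOp.getOperation()))
  return failure();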

Contributor:

> Similarly to Diego, I am suggesting to move the high level logic to vectorizeOpPrecondition. Also, like other "pre-condition" hooks, it should not require a rewriter.

I thought that it was moved to vectorizeOpPrecondition in this PR? The check is in vectorizeLinalgOpPrecondition, and the former calls this function. Do you suggest creating a different function like vectorizeConvPrecondition and using it in vectorizeOpPrecondition? That is okay with me, because convolution really takes a different path.

RE the verification issue: I totally agree that the verification should not depend on an IR rewriter. From what I can tell, we do not need it at all. The class needs it for StructuredGenerator, but we don't need it in the verification at all.

// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
  return;
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
  return;
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
// (non-channeled convolution -> LHS and RHS both have single dimensions).
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
    (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
  return;
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
if (!reduceOp)
  return;
redOp = reduceOp->getName().getIdentifier();
if (!setOperKind(reduceOp))
  return;
auto maybeKind = getCombinerOpKind(reduceOp);
// Typically convolution will have a `Add` CombiningKind but for i1 type it
// can get strength reduced to `OR` which is also supported. This strength
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                    *maybeKind != vector::CombiningKind::OR) &&
                   (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
  return;
}
reductionKind = maybeKind.value();
auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
    return;
  break;
case Pool:
  if (rhsRank != 1)
    return;
  break;
}
// The op is now known to be valid.
valid = true;

The valid variable is only used in assertions in a few methods, e.g., depthwiseConv and conv. I think it was mainly created as a sanity check, and the new code did not take it into account; thus we crashed in the other place.

The code is quite old, and the precondition was added later than the conv code. I think that to structure it better, we can refactor the generator, because everything starts from the generator. How about we have a static class method which returns true when the given operation is supported? That is, we move the above logic check to a static method (e.g., vectorizePrecondition) without initializing any variables.

In the constructor, I'd suggest doing as little as possible, and we move the assertion out of the constructor; in this context, it is moved to an initializer method. This is because I'd prefer to avoid a crash in the constructor, and we can expose the failure handling to external users. (I don't know what the style is in LLVM, but it is quite common in environments where exceptions are disallowed. See https://abseil.io/tips/42 for more details.)

Thus, it can be something like

struct Conv1DGenerator : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  // The constructor only takes the rewriter and the linalg op.
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {}

  // Vectorization pre-condition.
  bool/LogicalResult vectorizePrecondition(LinalgOp linalgOp) { ... }

  // The initialization method.
  LogicalResult init() {
    // Or do an assertion here.
    if (failed(vectorizePrecondition(...))) {
      return failure();
    }
    // Initialize the values for the class members.
  }
};
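A hypothetical call site under this pattern would then look like:

// Sketch only: construction stays trivial; init() performs the validity
// check before any vectorization entry point is used.
Conv1DGenerator gen(rewriter, linalgOp);
if (failed(gen.init()))
  return rewriter.notifyMatchFailure(linalgOp, "unsupported 1-D conv/pool");
FailureOr<Operation *> res = gen.generateNonChanneledConv();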

Does it look better structured?

if (!validateConv1DGenerator(dummyBuilder, linalgOp))
return failure();

return success();
}

// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
@@ -3125,10 +3144,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
/// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
int dilationW)
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
strideW(strideW), dilationW(dilationW) {
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return;
@@ -3175,10 +3192,22 @@ struct Conv1DGenerator
return;
break;
}

// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;

// The op is now known to be valid.
valid = true;
}

bool isValid() { return valid; }

/// Generate a vector implementation for:
/// ```
/// Op def: ( w, kw )
@@ -3889,22 +3918,21 @@
}
}
};

// Helper function to construct Conv1DGenerator
bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
Conv1DGenerator conv1dGen(rewriter, linalgOp);
return conv1dGen.isValid();
}

} // namespace

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
Conv1DGenerator e(rewriter, op, stride, dilation);
Conv1DGenerator e(rewriter, op);
auto res = e.generateNonChanneledConv();
if (succeeded(res))
return res;
@@ -3929,9 +3957,11 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
"Not a 1D depthwise conv!");
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
!isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
return rewriter.notifyMatchFailure(
op, "Unexpected convolution: expected 1D depthwise conv");
}
size_t chDimIdx =
TypeSwitch<Operation *, size_t>(op)
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {

// -----

func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%5 = tensor.empty() : tensor<1x64x56x56xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
// expected-error @+1 {{Attempted to vectorize, but failed}}
%7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 : !transform.any_op
transform.yield
}
}

// -----

func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
%pad = arith.constant 0.000000e+00 : f32
// expected-error @+1 {{Attempted to vectorize, but failed}}