Skip to content

Commit e811d48

Browse files
Tests reviews + clean up more
1 parent c82c3d3 commit e811d48

File tree

4 files changed

+75
-107
lines changed

4 files changed

+75
-107
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115
//===----------------------------------------------------------------------===//
116116

117117
/// Given a linalg `op` this function returns true if it is a convolution op of
118-
/// type `ConvOpTy` and populate `dilations` and `strides` arguments.
118+
/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
119+
/// from the indexing maps.
119120
template <typename ConvOpTy>
120121
bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
121122
SmallVector<int64_t> *strides);

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,10 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
249249
SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
250250
? TypeRange(ValueRange(outputs))
251251
: TypeRange{};
252-
LinalgOp namedOp;
253-
if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
254-
std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
255-
std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
256-
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
257-
inputs, outputs);
258-
} else {
259-
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
260-
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
261-
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
262-
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
263-
}
252+
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
253+
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
254+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
255+
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
264256
return namedOp;
265257
}
266258

0 commit comments

Comments
 (0)