Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -621,15 +621,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
buildMatmulOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
Expand All @@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addAttribute("cast", cast);
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>

Expand Down
123 changes: 67 additions & 56 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
}

/// Wrapper to return the typical indexing map array attribute for MatmulOp.
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
static SmallVector<Attribute>
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
return llvm::map_to_vector(
getDefaultIndexingMapsForMatmul(context),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
Expand All @@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(
OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder,
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
static void buildStructuredOp(OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes,
ValueRange inputs, ValueRange outputs,
ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder) {
// Derive the result types if needed.
SmallVector<Type> derivedResultTypes =
resultTensorTypes.value_or(TypeRange());
Expand All @@ -196,6 +196,24 @@ static void buildStructuredOp(
state.addOperands(outputs);
state.addTypes(derivedResultTypes);

state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));

// Create and fill the region of the structured operation.
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
state.attributes.getAttrs(), regionBuilder);
}

static void
buildMatmulOp(OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder,
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
// Initialize indexingMaps, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
if (indexingMaps.has_value()) {
Expand All @@ -205,20 +223,11 @@ static void buildStructuredOp(
}
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
} else {
indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
}

state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));

// Create and fill the region of the structured operation.
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
state.attributes.getAttrs(), regionBuilder);
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
Expand Down Expand Up @@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) {

SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
if (parser.parseEqual())
return failure();

if (parser.parseLSquare())
return failure();

do {
if (parser.parseAttribute(mapAttr))
return failure();
if (!isa<AffineMapAttr>(mapAttr)) {
return parser.emitError(parser.getCurrentLocation(),
"expected affine map attribute");
}
indexingMapsAttr.push_back(mapAttr);

if (parser.parseOptionalComma())
break;
} while (true);

if (parser.parseRSquare())
return failure();
}
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
}
result.addAttribute("indexing_maps",
parser.getBuilder().getArrayAttr(indexingMapsAttr));

// TODO: Enable when ods-gen supports captures.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
Expand Down Expand Up @@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
Expand All @@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Check if the op has broadcast and/or transpose semantic. Returns true if the
/// user defined indexing maps are not equal to default map.
/// Check if the op has broadcast and/or transpose semantic. Returns true if
/// the user defined indexing maps are not equal to default map.
bool MatmulOp::hasUserDefinedMaps() {
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
Expand Down Expand Up @@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
helper.yieldOutputs(yields);
}

/// Returns a list of AffineMap with the typical matmul indexing charactristic.
/// Returns a list of AffineMap with the typical matmul indexing
/// charactristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
MLIRContext *context = this->getContext();
return getDefaultIndexingMapsForMatmul(context);
Expand All @@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
if (parser.parseEqual())
return failure();

if (parser.parseLSquare())
return failure();

do {
if (parser.parseAttribute(mapAttr))
return failure();
if (!isa<AffineMapAttr>(mapAttr)) {
return parser.emitError(parser.getCurrentLocation(),
"expected affine map attribute");
}
indexingMapsAttr.push_back(mapAttr);

if (parser.parseOptionalComma())
break;
} while (true);

if (parser.parseRSquare())
return failure();
}
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
}
result.addAttribute("indexing_maps",
parser.getBuilder().getArrayAttr(indexingMapsAttr));

return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
Expand All @@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
elidedAttrs);

SmallVector<Attribute, 3> indexingMaps =
getDefaultIndexingMapAttr(getContext());
getDefaultMatmulIndexingMapAttr(getContext());
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
p << " indexing_maps = [";
llvm::interleaveComma(getIndexingMaps(), p,
Expand Down
12 changes: 8 additions & 4 deletions mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
// CHECK-NEXT: return %[[RES]]
%1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
Expand All @@ -62,7 +63,8 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
Expand Down Expand Up @@ -113,7 +115,8 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
// CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
// CHECK-NEXT: return %[[RES]]
Expand All @@ -140,7 +143,8 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
Expand Down
3 changes: 1 addition & 2 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
}
void {0}::print(OpAsmPrinter &p) {{
SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
"linalg.memoized_indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}
Expand Down
Loading