Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -621,15 +621,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
buildMatmulOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
Expand All @@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addAttribute("cast", cast);
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>

Expand Down
123 changes: 67 additions & 56 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
}

/// Wrapper to return the typical indexing map array attribute for MatmulOp.
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
static SmallVector<Attribute>
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
return llvm::map_to_vector(
getDefaultIndexingMapsForMatmul(context),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
Expand All @@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(
OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder,
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
static void buildStructuredOp(OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes,
ValueRange inputs, ValueRange outputs,
ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder) {
// Derive the result types if needed.
SmallVector<Type> derivedResultTypes =
resultTensorTypes.value_or(TypeRange());
Expand All @@ -196,6 +196,24 @@ static void buildStructuredOp(
state.addOperands(outputs);
state.addTypes(derivedResultTypes);

state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));

// Create and fill the region of the structured operation.
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
state.attributes.getAttrs(), regionBuilder);
}

static void
buildMatmulOp(OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder,
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
// Initialize indexingMaps, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
if (indexingMaps.has_value()) {
Expand All @@ -205,20 +223,11 @@ static void buildStructuredOp(
}
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
} else {
indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
}

state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));

// Create and fill the region of the structured operation.
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
state.attributes.getAttrs(), regionBuilder);
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
Expand Down Expand Up @@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) {

SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
if (parser.parseEqual())
return failure();

if (parser.parseLSquare())
return failure();

do {
if (parser.parseAttribute(mapAttr))
return failure();
if (!isa<AffineMapAttr>(mapAttr)) {
return parser.emitError(parser.getCurrentLocation(),
"expected affine map attribute");
}
indexingMapsAttr.push_back(mapAttr);

if (parser.parseOptionalComma())
break;
} while (true);

if (parser.parseRSquare())
return failure();
}
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
}
result.addAttribute("indexing_maps",
parser.getBuilder().getArrayAttr(indexingMapsAttr));

// TODO: Enable when ods-gen supports captures.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
Expand Down Expand Up @@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
Expand All @@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Check if the op has broadcast and/or transpose semantic. Returns true if the
/// user defined indexing maps are not equal to default map.
/// Check if the op has broadcast and/or transpose semantic. Returns true if
/// the user defined indexing maps are not equal to default map.
bool MatmulOp::hasUserDefinedMaps() {
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
Expand Down Expand Up @@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
helper.yieldOutputs(yields);
}

/// Returns a list of AffineMap with the typical matmul indexing charactristic.
/// Returns a list of AffineMap with the typical matmul indexing
/// charactristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
MLIRContext *context = this->getContext();
return getDefaultIndexingMapsForMatmul(context);
Expand All @@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<Attribute, 3> indexingMapsAttr;
Attribute mapAttr;
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
if (parser.parseEqual())
return failure();

if (parser.parseLSquare())
return failure();

do {
if (parser.parseAttribute(mapAttr))
return failure();
if (!isa<AffineMapAttr>(mapAttr)) {
return parser.emitError(parser.getCurrentLocation(),
"expected affine map attribute");
}
indexingMapsAttr.push_back(mapAttr);

if (parser.parseOptionalComma())
break;
} while (true);

if (parser.parseRSquare())
return failure();
}
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
}
result.addAttribute("indexing_maps",
parser.getBuilder().getArrayAttr(indexingMapsAttr));

return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
Expand All @@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
elidedAttrs);

SmallVector<Attribute, 3> indexingMaps =
getDefaultIndexingMapAttr(getContext());
getDefaultMatmulIndexingMapAttr(getContext());
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
p << " indexing_maps = [";
llvm::interleaveComma(getIndexingMaps(), p,
Expand Down
3 changes: 1 addition & 2 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
}
void {0}::print(OpAsmPrinter &p) {{
SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
"linalg.memoized_indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}
Expand Down
Loading