Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)

add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)

add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generate the C++ op declarations (.h.inc) and definitions (.cpp.inc) from
# the transform op descriptions in X86VectorTransformOps.td.
set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
# Tablegen target under which the two generated files above are built.
add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// X86Vector Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace x86vector {
/// Adds the X86Vector transform dialect extension, which carries the
/// X86Vector transform operations declared above, to `registry`.
void registerTransformDialectExtension(DialectRegistry &registry);

} // namespace x86vector
} // namespace mlir

#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef X86VECTOR_TRANSFORM_OPS
#define X86VECTOR_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/IR/RegionKindInterface.td"

// Pattern descriptor op exposing the `vector.contract` to FMA lowering to the
// transform dialect's `apply_patterns` mechanism.
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.vector_contract_to_fma",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Indicates that `vector.contract` operations can be lowered to a sequence
    of `vector.broadcast` and `vector.fma` operations.
  }];

  let assemblyFormat = "attr-dict";
}

// Pattern descriptor op exposing the `vector.contract` to packed-type
// dot-product lowering to the transform dialect's `apply_patterns` mechanism.
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Indicates that `vector.contract` operations can be lowered to BF16/Int8
    dot-product operations.
  }];

  let assemblyFormat = "attr-dict";
}


#endif // X86VECTOR_TRANSFORM_OPS

11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

#include "mlir/IR/Value.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {

class ImplicitLocOpBuilder;
Expand Down Expand Up @@ -79,6 +83,13 @@ struct MaskHelper {
}
};

//===----------------------------------------------------------------------===//

/// Adds to `patterns` a rewrite that lowers F32 `vector.contract` operations
/// with a unit-shaped LHS into a broadcast followed by a `vector.fma`.
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);

/// Adds to `patterns` the rewrites that lower `vector.contract` operations to
/// packed-type dot-products.
/// NOTE(review): presumably targets BF16/Int8 operands, as the matching
/// transform op description says; the implementation is not visible here, so
/// confirm against VectorContractToPackedTypeDotProduct.cpp.
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/X86Vector/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Library holding the X86Vector transform ops (X86VectorTransformOps.cpp).
# It depends on the tablegen target so that the generated op .inc files exist
# before this library is compiled.
add_mlir_dialect_library(MLIRX86VectorTransformOps
X86VectorTransformOps.cpp

DEPENDS
MLIRX86VectorTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
MLIRX86VectorDialect
MLIRX86VectorTransforms
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//===- X86VectorTransformOps.cpp - X86Vector transform ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"

using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::transform;

void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateVectorContractToFMAPatterns(patterns);
}

void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Transform dialect extension that registers the X86Vector transform ops.
class X86VectorTransformDialectExtension
: public transform::TransformDialectExtension<
X86VectorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
X86VectorTransformDialectExtension)

/// Declares the X86Vector and LLVM dialects as generated dialects and
/// registers every op listed in the tablegen-generated op list.
X86VectorTransformDialectExtension() {
declareGeneratedDialect<x86vector::X86VectorDialect>();
declareGeneratedDialect<LLVM::LLVMDialect>();
// With GET_OP_LIST defined, the include below expands to the
// comma-separated op class list used as the template arguments.
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"

void mlir::x86vector::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<X86VectorTransformDialectExtension>();
}
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
VectorContractToPackedTypeDotProduct.cpp

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
99 changes: 99 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===- VectorContractToFMA.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

// Implements outer product contraction as a sequence of broadcast and
// FMA operations.
//
// For example - for F32 type:
// ```
// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
// ```
// to
// ```
// vector.broadcast %lhs to <16xf32>
// vector.fma vector<16xf32>
// ```
struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {

if (contractOp.getKind() != vector::CombiningKind::ADD) {
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind");
}

VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isF32())
return rewriter.notifyMatchFailure(contractOp,
"Only F32 lowering is supported.");
if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, this would also work with B as the broadcasted element, no?

return rewriter.notifyMatchFailure(
contractOp, "Expects one for all dimensions of LHS");

VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
llvm::SmallVector<int64_t> dimsRhs;
llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this be count(rhsShape, 1) == rank(rhsShape) - 1?

Would avoid allocation and copies.

[](int64_t dim) { return dim != 1; });
if (dimsRhs.size() != 1)
return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");

VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
assert(accTy && "Invalid accumulator");
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> dimsAcc;
llvm::copy_if(accShape, std::back_inserter(dimsAcc),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be just accShape == rhsShape?

[](int64_t dim) { return dim != 1; });
if (dimsAcc.size() != 1)
return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");

// Lowers vector.contract into a broadcast+FMA sequence.
auto loc = contractOp.getLoc();
auto castLhs = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
contractOp.getLhs());
auto castRhs = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
contractOp.getRhs());
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
contractOp.getAcc());
auto broadcastLhs = vector::BroadcastOp::create(
rewriter, loc, castRhs.getResult().getType(), castLhs);
auto fma =
vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);

rewriter.replaceOp(contractOp, castFma);

return success();
}
};

/// Registers the `vector.contract` to broadcast+FMA rewrite in `patterns`.
void x86vector::populateVectorContractToFMAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<VectorContractToFMA>(ctx);
}
Loading
Loading