Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions include/gc/Analysis/MatmulConfigAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,14 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
if (llvm::isa<linalg::ContractionOpInterface>(linalgOp.getOperation())) {
return getContractionOpOperandDimType(linalgOp);
} else if (linalgx::isGenericPackedMatmulOp(
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D) ||
llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
DimType::K},
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
} else if (linalgx::isGenericPackedMatmulOp(
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D) ||
llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
Expand Down
208 changes: 0 additions & 208 deletions include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,212 +104,4 @@ def Linalgx_SigmoidOp : LinalgxStructuredBase_Op<"sigmoid",
}];
}

def Linalgx_Mm2DVnniOp
: LinalgxStructuredBase_Op<"mm2d_vnni", [AttrSizedOperandSegments]> {
let summary = "Transposed matmul with 2d input and vnni packed weights";
let description = [{
Supported format: A[M, K] * B[N0, K0, k, n, v] -> C[M, N], with:
N = N0 * n
K = K0 * k * v; v = (2, 4)
}];
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$outputs);
let results = (outs Variadic<TensorOrMemref>:$results);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins
"TypeRange":$resultTensorTypes,
"ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, Mm2DVnniOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static unsigned getNumRegionArgs() { return 3; }
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}
}];
}

def Linalgx_Mm4DVnniOp
: LinalgxStructuredBase_Op<"mm4d_vnni", [AttrSizedOperandSegments]> {
let summary = "Transposed matmul with 4d blocking input and vnni packed weights";
let description = [{
Supported format: A[M, K, m, k] * B[N, K, k0, n, v] -> C[M, N, m, n], with:
k = k0 * v; v = (2, 4)
}];
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$outputs);
let results = (outs Variadic<TensorOrMemref>:$results);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins
"TypeRange":$resultTensorTypes,
"ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, Mm4DVnniOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static unsigned getNumRegionArgs() { return 3; }
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}
}];
}

def Linalgx_BatchReduceMatmulVnniOp
: LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
let description = [{
Supported format: A[B, M, K] * B[B, k, N, v] -> C[M, N], with:
K = k * v; v = (2, 4)
}];
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$outputs);
let results = (outs Variadic<TensorOrMemref>:$results);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins
"TypeRange":$resultTensorTypes,
"ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, BatchReduceMatmulVnniOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static unsigned getNumRegionArgs() { return 3; }
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}
}];
}

def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
[AttrSizedOperandSegments, LinalgContractionOpInterface]> {
let summary = "Batch matmul with variable batch dims";
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$outputs);
let results = (outs Variadic<TensorOrMemref>:$results);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins
"TypeRange":$resultTensorTypes,
"ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, MultiBatchMatmulOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static unsigned getNumRegionArgs() { return 3; }
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}
}];
}

#endif // LINALGX_STRUCTURED_OPS
26 changes: 25 additions & 1 deletion include/gc/Dialect/Linalgx/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace mlir {
namespace linalgx {

/// @brief enum of type of matmul packing
enum class PackingType {
enum class PackingType : int {
MM4D = 0, // MKmk x NKkn
VNNI_MM2D, // MK x NKknV
VNNI_MM4D, // MKmk x NKknV
Expand All @@ -43,6 +43,30 @@ makeGenericPackedMatmulOp(OpBuilder &builder, Location loc, PackingType opType,
/// @return true if op is a generic packed matmul Op
bool isGenericPackedMatmulOp(Operation *op, PackingType opType);

template <typename... Args>
inline bool isGenericPackedMatmulOp(Operation *op, PackingType first,
Args... args) {
return isGenericPackedMatmulOp(op, first) ||
isGenericPackedMatmulOp(op, args...);
}

/// @brief identify a generic packed matmul Op based on any PackingType
/// @param op the op
/// @return true if op is a generic packed matmul Op
template <int T, int N> inline bool isAnyGenericPackedMatmulOp(Operation *op) {
return isGenericPackedMatmulOp(op, (PackingType)N) ||
isAnyGenericPackedMatmulOp<T + 1, N>(op);
}
constexpr int NUM_ALL_TYPES = (int)PackingType::NUM_TYPES;
template <>
inline bool
isAnyGenericPackedMatmulOp<NUM_ALL_TYPES, NUM_ALL_TYPES>(Operation *op) {
return false;
}
inline bool isAnyGenericPackedMatmulOp(Operation *op) {
return isAnyGenericPackedMatmulOp<0, NUM_ALL_TYPES>(op);
}

/// @brief identify a matmul Op based on ContractionOp and PackingType
/// @param op the op
/// @return true if op is a matmul Op
Expand Down
Loading