Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mlir_tablegen(RockAccelTuningParamAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(RockAccelTuningParamAttrInterface.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(MLIRRockAccelTuningParamAttrInterfaceIncGen)

add_mlir_interface(RockGemmGemmWrapperInterface)
add_mlir_interface(RockGemmWrapperInterface)
add_mlir_interface(RockConvInterface)
add_mlir_interface(RockAcceptingViewOpInterface)
Expand Down
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===--------- GemmGemmSize.h - utility struct for gemm+gemm ----------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a utility struct, GemmGemmSize, that packages the sizes of
// gemm+gemm to ensure a cleaner API.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H
#define MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H

#include <cstdint>

namespace mlir {
namespace rock {

/// Structure for holding the sizes of a matrix multiplication operation.
struct GemmGemmSize {
int64_t g;
int64_t m;
int64_t k;
int64_t n;
int64_t o;

GemmGemmSize(int64_t g, int64_t m, int64_t k, int64_t n, int64_t o)
: g(g), m(m), k(k), n(n), o(o) {}

Check warning on line 31 in mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h

View check run for this annotation

Codecov / codecov/patch

mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h#L31

Added line #L31 was not covered by tests

bool operator==(const GemmGemmSize &other) {
return (g == other.g) && (m == other.m) && (k == other.k) &&
(n == other.n) && (o == other.o);
}

Check warning on line 36 in mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h

View check run for this annotation

Codecov / codecov/patch

mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h#L33-L36

Added lines #L33 - L36 were not covered by tests
};
} // end namespace rock
} // end namespace mlir
#endif // MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H
8 changes: 2 additions & 6 deletions mlir/include/mlir/Dialect/Rock/IR/Rock.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class PatternRewriter;
#include "mlir/Dialect/Rock/IR/RockTypes.h"

#include "mlir/Dialect/Rock/IR/ConvolutionDims.h"
#include "mlir/Dialect/Rock/IR/GemmGemmSize.h"
#include "mlir/Dialect/Rock/IR/GemmSize.h"

namespace mlir {
Expand All @@ -49,12 +50,6 @@ class FusionRoot : public TraitBase<ConcreteType, FusionRoot> {};
} // namespace OpTrait
} // namespace mlir

// Following ifdef could be used to change
// the attention operator to be a fused gemm-gemm
// kernel for debugging purposes. This will also
// adjust the test harness to verify the same as well
// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX

namespace mlir {
namespace rock {
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -87,6 +82,7 @@ constexpr int64_t maxHardwareWorkgroupSize = 1024;

#include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.h"
#include "mlir/Dialect/Rock/IR/RockConvInterface.h"
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
#include "mlir/Dialect/Rock/IR/RockWriterOpInterface.h"

Expand Down
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define ROCK_ATTRS

include "mlir/Dialect/Rock/IR/RockBase.td"
include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td"
include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td"
include "mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.td"
include "mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td"
Expand Down Expand Up @@ -63,11 +64,13 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>;
def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>;
def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>;
def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>;

def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
[KernelTypeConv, KernelTypeConvBwdData,
KernelTypeConvBwdWeight, KernelTypeGemm,
KernelTypeAttention]>;
def KernelType
: Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
[KernelTypeConv, KernelTypeConvBwdData,
KernelTypeConvBwdWeight, KernelTypeGemm,
KernelTypeAttention, KernelTypeGemmElementwiseGemm]>;

/// TransformType
def PassThrough : I32EnumAttrCase<"PassThrough", 0>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- RockGemmGemmWrapperInterface.h - ops that wrap rock.attention -*- C++
//-*-===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (c) 2025 Advanced Micro Devices INc.
//===----------------------------------------------------------------------===//
//
// This file defines RockGemmGemmWrapperInterface, which abstracts attention and
// gemm+gemm to allow code to operate on them generically.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H
#define MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H

#include "mlir/Dialect/Rock/IR/GemmGemmSize.h"
#include "mlir/IR/OpDefinition.h"

#include "mlir/Dialect/Rock/IR/RockTypes.h"

#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h.inc"

#endif // MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H
252 changes: 252 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
//===- RockGemmGemmWrapperInterface.td - ops that wrap rock.attention
//---------===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (c) 2025 Advanced Micro Devices INc.
//===----------------------------------------------------------------------===//
//
// This file defines RockGemmGemmWrapperInterface, which abstracts attention and
// gemm+gemm and friends (conv+gemm, ...) to allow code to operate on them
// generically.
//
//===----------------------------------------------------------------------===//

#ifndef ROCK_GEMM_GEMM_WRAPPER_INTERFACE
#define ROCK_GEMM_GEMM_WRAPPER_INTERFACE

include "mlir/IR/OpBase.td"

def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
let description = [{
Interface to abstract away gemm+gemm-wrapping operators in the rock dialect,
which mainly include attention and gemm+gemm and friends that can be implemented
with flash attention.

This should include functions to get common attributes.
}];
let cppNamespace = "::mlir::rock";

let methods = [
InterfaceMethod<
/*desc=*/[{
Return the KernelType of this op
}],
/*retType=*/"::mlir::rock::KernelType",
/*methodName=*/"getKernelType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the arch string of this op
}],
/*retType=*/"StringRef",
/*methodName=*/"getArch",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the OpOperand that corresponds to the operand argument
that corresponds to the output result of the operation.
}],
/*retType=*/"OpOperand *",
/*methodName=*/"getOutArgument",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the size of the matrix multiplication that this op will eventually
perform.
}],
/*retType=*/"::mlir::rock::GemmGemmSize",
/*methodName=*/"getGemmGemmSize",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the element type of [what will become] matrix A for this operation.
}],
/*retType=*/"::mlir::Type",
/*methodName=*/"getAType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the element type of [what will become] matrix B for this operation.
}],
/*retType=*/"::mlir::Type",
/*methodName=*/"getBType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the element type of [what will become] matrix C for this operation.
}],
/*retType=*/"::mlir::Type",
/*methodName=*/"getCType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the element type of [what will become] output matrix for this operation.
}],
/*retType=*/"::mlir::Type",
/*methodName=*/"getOutType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the whether matrix A is transposed.
}],
/*retType=*/"bool",
/*methodName=*/"getTransposedA",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the whether matrix B is transposed.
}],
/*retType=*/"bool",
/*methodName=*/"getTransposedB",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the whether matrix C is transposed.
}],
/*retType=*/"bool",
/*methodName=*/"getTransposedC",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the whether output matrix is transposed.
}],
/*retType=*/"bool",
/*methodName=*/"getTransposedOut",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return the features attribute of this op.
}],
/*retType=*/"::mlir::rock::GemmFeatures",
/*methodName=*/"getGemmFeatures",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getFeatures();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the optional number of Compute Units the GPU provides.
}],
/*retType=*/"std::optional<uint32_t>",
/*methodName=*/"getNumCU",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/ ""
>,

InterfaceMethod<
/*desc=*/[{
Set the tuning parameters attribute of the first GEMM
}],
/*retType=*/"void",
/*methodName=*/"setGemm0ParamsAttr",
/*args=*/(ins "::mlir::Attribute":$params),
/*methodBody=*/"",
/*defaultImplementation=*/[{
$_op->setAttr($_op.getParams0AttrName(), params);
}]
>,
InterfaceMethod<
/*desc=*/[{
Set the tuning parameters attribute of the second GEMM
}],
/*retType=*/"void",
/*methodName=*/"setGemm1ParamsAttr",
/*args=*/(ins "::mlir::Attribute":$params),
/*methodBody=*/"",
/*defaultImplementation=*/[{
$_op->setAttr($_op.getParams1AttrName(), params);
}]
>,
InterfaceMethod<
/*desc=*/[{
Get the tuning parameters attribute of the first GEMM
}],
/*retType=*/"std::optional<RockTuningParamAttrInterface>",
/*methodName=*/"getGemm0Params",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getParams0();
}]
>,
InterfaceMethod<
/*desc=*/[{
Get the tuning parameters attribute of the second GEMM
}],
/*retType=*/"std::optional<RockTuningParamAttrInterface>",
/*methodName=*/"getGemm1Params",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getParams1();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the index of the elementwise region argument that comes from the first GEMM.
}],
/*retType=*/"uint32_t",
/*methodName=*/"getFirstGemmIndex",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/ ""
>,

// TODO: more methods here as needed
];

let verify = [{
auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
if ($_op->getNumResults() == 1) {
if ($_op->getResult(0).getType() !=
concreteOp.getOutArgument()->get().getType()) {
return $_op->emitOpError("result type must match output argument type");
}
}
return ::mlir::success();
}];
}

#endif // ROCK_GEMM_GEMM_WRAPPER_INTERFACE
Loading
Loading