77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Dialect/Rock/IR/Rock.h"
10- #include " mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
1110#include " mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
11+ #include " mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
1212#include " mlir/Dialect/Rock/IR/RockTypes.h"
1313#include " mlir/Dialect/Rock/utility/math.h"
1414
@@ -2081,48 +2081,38 @@ LogicalResult BlockwiseFillOp::verify() {
20812081// ===-----------------------------------------------------===//
20822082
20832083OpOperand *GemmElementwiseGemmOp::getOutArgument () {
2084- return &(*this )->getOpOperand (getNumOperands ()- 1 );
2084+ return &(*this )->getOpOperand (getNumOperands () - 1 );
20852085}
20862086
2087- Type GemmElementwiseGemmOp::getOutType () {
2088- return getOut ().getType ();
2089- }
2087+ Type GemmElementwiseGemmOp::getOutType () { return getOut ().getType (); }
20902088
2091- Type GemmElementwiseGemmOp::getAType () {
2092- return getA ().getType ();
2093- }
2089+ Type GemmElementwiseGemmOp::getAType () { return getA ().getType (); }
20942090
2095- Type GemmElementwiseGemmOp::getBType () {
2096- return getB ().getType ();
2097- }
2091+ Type GemmElementwiseGemmOp::getBType () { return getB ().getType (); }
20982092
2099- Type GemmElementwiseGemmOp::getCType () {
2100- return getC ().getType ();
2101- }
2093+ Type GemmElementwiseGemmOp::getCType () { return getC ().getType (); }
21022094
2103- bool GemmElementwiseGemmOp::getTransposedA () {
2104- return getATransposed ();
2105- }
2095+ bool GemmElementwiseGemmOp::getTransposedA () { return getATransposed (); }
21062096
2107- bool GemmElementwiseGemmOp::getTransposedB () {
2108- return getBTransposed ();
2109- }
2097+ bool GemmElementwiseGemmOp::getTransposedB () { return getBTransposed (); }
21102098
2111- bool GemmElementwiseGemmOp::getTransposedC () {
2112- return getCTransposed ();
2113- }
2099+ bool GemmElementwiseGemmOp::getTransposedC () { return getCTransposed (); }
21142100
2115- bool GemmElementwiseGemmOp::getTransposedOut () {
2116- return getOTransposed ();
2117- }
2101+ bool GemmElementwiseGemmOp::getTransposedOut () { return getOTransposed (); }
21182102
2119- KernelType GemmElementwiseGemmOp::getKernelType () { return KernelType::GemmElementwiseGemm; }
2103+ KernelType GemmElementwiseGemmOp::getKernelType () {
2104+ return KernelType::GemmElementwiseGemm;
2105+ }
21202106
2121- uint32_t GemmElementwiseGemmOp::getFirstGemmIndex () { return getFirstGemmIdx (); }
2107+ uint32_t GemmElementwiseGemmOp::getFirstGemmIndex () {
2108+ return getFirstGemmIdx ();
2109+ }
21222110
21232111GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize () {
2124- ShapedType typeA = getA ().getType (), typeB = getB ().getType (), typeC = getC ().getType ();
2125- ArrayRef<int64_t > dimsA = typeA.getShape (), dimsB = typeB.getShape (), dimsC = typeC.getShape ();
2112+ ShapedType typeA = getA ().getType (), typeB = getB ().getType (),
2113+ typeC = getC ().getType ();
2114+ ArrayRef<int64_t > dimsA = typeA.getShape (), dimsB = typeB.getShape (),
2115+ dimsC = typeC.getShape ();
21262116 int64_t offsetA = dimsA.size () == 2 ? 0 : 1 ,
21272117 offsetB = dimsB.size () == 2 ? 0 : 1 ,
21282118 offsetC = dimsC.size () == 2 ? 0 : 1 ;
@@ -2134,25 +2124,28 @@ GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
21342124 return GemmGemmSize (g, m, k, n, o);
21352125}
21362126
2137- static LogicalResult verifyAttentionOp (RockGemmGemmWrapperInterface op,
2127+ static LogicalResult verifyAttentionOp (RockGemmGemmWrapperInterface op,
21382128 Value currentSeqLen) {
21392129 ShapedType qType = cast<ShapedType>(op.getAType ());
21402130 int64_t qBatchDim = qType.getShape ().size () == 3 ? qType.getShape ()[0 ] : 1 ;
21412131 ArrayRef<int64_t > qLastDims = qType.getShape ().slice (qType.getRank () - 2 );
2142- auto [queryM, queryK] = op.getTransposedA () ? std::tuple{qLastDims[1 ], qLastDims[0 ]}
2143- : std::tuple{qLastDims[0 ], qLastDims[1 ]};
2132+ auto [queryM, queryK] = op.getTransposedA ()
2133+ ? std::tuple{qLastDims[1 ], qLastDims[0 ]}
2134+ : std::tuple{qLastDims[0 ], qLastDims[1 ]};
21442135
21452136 ShapedType kType = cast<ShapedType>(op.getBType ());
21462137 int64_t kBatchDim = kType .getShape ().size () == 3 ? kType .getShape ()[0 ] : 1 ;
21472138 ArrayRef<int64_t > kLastDims = kType .getShape ().slice (kType .getRank () - 2 );
2148- auto [keyK, keyN] = op.getTransposedB () ? std::tuple{kLastDims [1 ], kLastDims [0 ]}
2149- : std::tuple{kLastDims [0 ], kLastDims [1 ]};
2139+ auto [keyK, keyN] = op.getTransposedB ()
2140+ ? std::tuple{kLastDims [1 ], kLastDims [0 ]}
2141+ : std::tuple{kLastDims [0 ], kLastDims [1 ]};
21502142
21512143 ShapedType vType = cast<ShapedType>(op.getCType ());
21522144 int64_t vBatchDim = vType.getShape ().size () == 3 ? vType.getShape ()[0 ] : 1 ;
21532145 ArrayRef<int64_t > vLastDims = vType.getShape ().slice (vType.getRank () - 2 );
2154- auto [valueK, valueN] = op.getTransposedC () ? std::tuple{vLastDims[1 ], vLastDims[0 ]}
2155- : std::tuple{vLastDims[0 ], vLastDims[1 ]};
2146+ auto [valueK, valueN] = op.getTransposedC ()
2147+ ? std::tuple{vLastDims[1 ], vLastDims[0 ]}
2148+ : std::tuple{vLastDims[0 ], vLastDims[1 ]};
21562149
21572150 if (qBatchDim != kBatchDim || kBatchDim != vBatchDim) {
21582151 return op.emitError (" Batch dimensions do not match" );
@@ -2171,7 +2164,7 @@ static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
21712164 ArrayRef<int64_t > oLastDims = oType.getShape ().slice (oType.getRank () - 2 );
21722165 auto [outputSeqLen, outputHeadDim] =
21732166 op.getTransposedOut () ? std::tuple{oLastDims[1 ], oLastDims[0 ]}
2174- : std::tuple{oLastDims[0 ], oLastDims[1 ]};
2167+ : std::tuple{oLastDims[0 ], oLastDims[1 ]};
21752168
21762169 if (qType.getShape ().size () != oType.getShape ().size ()) {
21772170 return op.emitError (" Number of dimensions do not match (Q and Output)" );
@@ -2223,48 +2216,34 @@ void GemmElementwiseGemmOp::getEffects(
22232216// ===-----------------------------------------------------===//
22242217
22252218OpOperand *AttentionOp::getOutArgument () {
2226- return &(*this )->getOpOperand (getNumOperands ()- 1 );
2219+ return &(*this )->getOpOperand (getNumOperands () - 1 );
22272220}
22282221
2229- Type AttentionOp::getOutType () {
2230- return getOut ().getType ();
2231- }
2222+ Type AttentionOp::getOutType () { return getOut ().getType (); }
22322223
2233- Type AttentionOp::getAType () {
2234- return getQueries ().getType ();
2235- }
2224+ Type AttentionOp::getAType () { return getQueries ().getType (); }
22362225
2237- Type AttentionOp::getBType () {
2238- return getKeys ().getType ();
2239- }
2226+ Type AttentionOp::getBType () { return getKeys ().getType (); }
22402227
2241- Type AttentionOp::getCType () {
2242- return getValues ().getType ();
2243- }
2228+ Type AttentionOp::getCType () { return getValues ().getType (); }
22442229
2245- bool AttentionOp::getTransposedA () {
2246- return getQTransposed ();
2247- }
2230+ bool AttentionOp::getTransposedA () { return getQTransposed (); }
22482231
2249- bool AttentionOp::getTransposedB () {
2250- return getKTransposed ();
2251- }
2232+ bool AttentionOp::getTransposedB () { return getKTransposed (); }
22522233
2253- bool AttentionOp::getTransposedC () {
2254- return getVTransposed ();
2255- }
2234+ bool AttentionOp::getTransposedC () { return getVTransposed (); }
22562235
2257- bool AttentionOp::getTransposedOut () {
2258- return getOTransposed ();
2259- }
2236+ bool AttentionOp::getTransposedOut () { return getOTransposed (); }
22602237
22612238KernelType AttentionOp::getKernelType () { return KernelType::Attention; }
22622239
22632240uint32_t AttentionOp::getFirstGemmIndex () { return getFirstGemmIdx (); }
22642241
22652242GemmGemmSize AttentionOp::getGemmGemmSize () {
2266- ShapedType typeA = getQueries ().getType (), typeB = getKeys ().getType (), typeC = getValues ().getType ();
2267- ArrayRef<int64_t > dimsA = typeA.getShape (), dimsB = typeB.getShape (), dimsC = typeC.getShape ();
2243+ ShapedType typeA = getQueries ().getType (), typeB = getKeys ().getType (),
2244+ typeC = getValues ().getType ();
2245+ ArrayRef<int64_t > dimsA = typeA.getShape (), dimsB = typeB.getShape (),
2246+ dimsC = typeC.getShape ();
22682247 int64_t offsetA = dimsA.size () == 2 ? 0 : 1 ,
22692248 offsetB = dimsB.size () == 2 ? 0 : 1 ,
22702249 offsetC = dimsC.size () == 2 ? 0 : 1 ;
0 commit comments