@@ -502,9 +502,9 @@ getElementwiseRegion(Value input, OpBuilder ®ionBuilder, Block *block,
502502
503503 FailureOr<mlir::tosa::MatMulOp> maybeMatMul = failure ();
504504 for (auto operand : op->getOperands ()) {
505- auto [result, maybeSubTreeMatMul] =
506- getElementwiseRegion<OpT>( operand, regionBuilder, block, elementwiseArgs,
507- loc, doRewrite, recDepth + 1 );
505+ auto [result, maybeSubTreeMatMul] = getElementwiseRegion<OpT>(
506+ operand, regionBuilder, block, elementwiseArgs, loc, doRewrite ,
507+ recDepth + 1 );
508508 mapper.map (operand, result);
509509 newOperands.push_back (result);
510510 if (succeeded (maybeSubTreeMatMul)) {
@@ -1049,7 +1049,8 @@ struct GemmElementwiseGemmRewritePattern
10491049 SmallVector<Value> vec;
10501050 FailureOr<tosa::MatMulOp> maybeFirstMatMul;
10511051 std::tie (std::ignore, maybeFirstMatMul) =
1052- getElementwiseRegion<rock::GemmElementwiseGemmOp>(op.getA (), b, nullptr , vec);
1052+ getElementwiseRegion<rock::GemmElementwiseGemmOp>(op.getA (), b, nullptr ,
1053+ vec);
10531054
10541055 if (succeeded (maybeFirstMatMul))
10551056 LLVM_DEBUG (llvm::dbgs ()
@@ -1072,8 +1073,9 @@ struct GemmElementwiseGemmRewritePattern
10721073 SmallVector<Value> elementwiseOtherArgs;
10731074
10741075 FailureOr<tosa::MatMulOp> maybeFirstMatMul;
1075- std::tie (std::ignore, maybeFirstMatMul) = getElementwiseRegion<rock::GemmElementwiseGemmOp>(
1076- op.getA (), rewriter, nullptr , elementwiseOtherArgs);
1076+ std::tie (std::ignore, maybeFirstMatMul) =
1077+ getElementwiseRegion<rock::GemmElementwiseGemmOp>(
1078+ op.getA (), rewriter, nullptr , elementwiseOtherArgs);
10771079 // This is guranteed by the matcher
10781080 tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value ();
10791081 IntegerAttr numCUAttr =
@@ -1099,8 +1101,9 @@ struct GemmElementwiseGemmRewritePattern
10991101 rewriter.setInsertionPointToStart (preSecondGemmElemwiseBlock);
11001102 Value res;
11011103 std::tie (res, maybeMatMul) =
1102- getElementwiseRegion<rock::GemmElementwiseGemmOp>(op.getA (), rewriter, preSecondGemmElemwiseBlock,
1103- elementwiseOtherArgs, loc, true );
1104+ getElementwiseRegion<rock::GemmElementwiseGemmOp>(
1105+ op.getA (), rewriter, preSecondGemmElemwiseBlock,
1106+ elementwiseOtherArgs, loc, true );
11041107 RankedTensorType resTensorType = cast<RankedTensorType>(res.getType ());
11051108 MemRefType resMemRefType = MemRefType::get (
11061109 resTensorType.getShape (), resTensorType.getElementType ());
@@ -1388,7 +1391,7 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
13881391 SmallVector<Value> vec;
13891392 FailureOr<tosa::MatMulOp> maybeFirstMatMul;
13901393 std::tie (std::ignore, maybeFirstMatMul) =
1391- getElementwiseRegion<rock::AttentionOp>(softmaxInput, b, nullptr , vec);
1394+ getElementwiseRegion<rock::AttentionOp>(softmaxInput, b, nullptr , vec);
13921395
13931396 if (succeeded (maybeFirstMatMul)) {
13941397 TypedValue<TensorType> matC = maybeFirstMatMul.value ().getC ();
@@ -1425,8 +1428,9 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
14251428 SmallVector<Value> elementwiseOtherArgs;
14261429
14271430 FailureOr<tosa::MatMulOp> maybeFirstMatMul;
1428- std::tie (std::ignore, maybeFirstMatMul) = getElementwiseRegion<rock::AttentionOp>(
1429- softmaxInput, rewriter, nullptr , elementwiseOtherArgs);
1431+ std::tie (std::ignore, maybeFirstMatMul) =
1432+ getElementwiseRegion<rock::AttentionOp>(softmaxInput, rewriter, nullptr ,
1433+ elementwiseOtherArgs);
14301434 // This is guranteed by the matcher
14311435 tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value ();
14321436 IntegerAttr numCUAttr =
@@ -1457,9 +1461,9 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
14571461 PatternRewriter::InsertionGuard guard (rewriter);
14581462 rewriter.setInsertionPointToStart (preSoftmaxElemwiseBlock);
14591463 Value res;
1460- std::tie (res, maybeMatMul) =
1461- getElementwiseRegion<rock::AttentionOp>( softmaxInput, rewriter, preSoftmaxElemwiseBlock,
1462- elementwiseOtherArgs, loc, true );
1464+ std::tie (res, maybeMatMul) = getElementwiseRegion<rock::AttentionOp>(
1465+ softmaxInput, rewriter, preSoftmaxElemwiseBlock, elementwiseOtherArgs ,
1466+ loc, true );
14631467 RankedTensorType resTensorType = cast<RankedTensorType>(res.getType ());
14641468 MemRefType resMemRefType = MemRefType::get (
14651469 resTensorType.getShape (), resTensorType.getElementType ());
0 commit comments