@@ -905,6 +905,23 @@ static bool isElementwiseOp(Operation *op) {
905905struct AttentionRewritePattern : public OpRewritePattern <tosa::MatMulOp> {
906906 using OpRewritePattern<tosa::MatMulOp>::OpRewritePattern;
907907
908+ FailureOr<Value> getValueNonReshapeOpNonBroadcast (Value val) const {
909+ while (val.getDefiningOp () &&
910+ (val.getDefiningOp <tensor::CollapseShapeOp>() ||
911+ val.getDefiningOp <tensor::ExpandShapeOp>() ||
912+ val.getDefiningOp <tosa::TransposeOp>() ||
913+ val.getDefiningOp <tosa::AddOp>())) {
914+ if (val.getDefiningOp <tosa::AddOp>()) {
915+ auto maybeBroadcast = addBroadcast (val);
916+ if (failed (maybeBroadcast))
917+ return failure ();
918+ val = maybeBroadcast.value ();
919+ } else
920+ val = val.getDefiningOp ()->getOperand (0 );
921+ }
922+ return val;
923+ }
924+
908925 Value getValueNonReshapeOp (Value val) const {
909926 while (val.getDefiningOp () &&
910927 (val.getDefiningOp <tensor::CollapseShapeOp>() ||
@@ -1004,8 +1021,35 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
10041021 if (failed (maybeNonZero2))
10051022 return failure ();
10061023
1024+ // check that the right dimensions are broadcasted
1025+ auto beforeBroadcastShape =
1026+ dyn_cast<ShapedType>(maybeNonZero2->getType ());
1027+ if (beforeBroadcastShape) {
1028+ auto shape = beforeBroadcastShape.getShape ();
1029+ if (beforeBroadcastShape.getRank () > 2 &&
1030+ !llvm::all_of (shape.slice (2 ), [](int32_t v) { return v == 1 ; }))
1031+ return failure ();
1032+ } else {
1033+ return failure ();
1034+ }
1035+
10071036 Value currentSeqLen = getValueNonReshapeOp (maybeNonZero2.value ());
10081037 Value result = select.getOnFalse ();
1038+
1039+ // currentSeqLen must be of i32 type
1040+ auto currentSeqLenShape = dyn_cast<ShapedType>(currentSeqLen.getType ());
1041+ if (!currentSeqLenShape ||
1042+ !currentSeqLenShape.getElementType ().isInteger (32 ))
1043+ return failure ();
1044+
1045+ // we'll check now if currentSeqLen comes from a block argument
1046+ FailureOr<Value> mustBeBlockArg =
1047+ getValueNonReshapeOpNonBroadcast (currentSeqLen);
1048+
1049+ if (failed (mustBeBlockArg) ||
1050+ !isa<BlockArgument>(mustBeBlockArg.value ()))
1051+ return failure ();
1052+
10091053 return std::make_pair (result, currentSeqLen);
10101054 }
10111055 }
@@ -1216,9 +1260,8 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
12161260 LogicalResult match (tosa::MatMulOp op) const override {
12171261 FailureOr<std::tuple<Value, bool , Value>> softmaxInputResult =
12181262 maybeSoftmax (op.getA ());
1219- if (failed (softmaxInputResult)) {
1263+ if (failed (softmaxInputResult))
12201264 return failure ();
1221- }
12221265
12231266 Value softmaxInput, currentSeqLen;
12241267 bool hasReduceOp;
@@ -1245,12 +1288,10 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
12451288 LLVM_DEBUG (llvm::dbgs ()
12461289 << " first matmul = " << maybeFirstMatMul.value () << " \n " );
12471290 LLVM_DEBUG (llvm::dbgs () << " hasReduceOp = " << hasReduceOp << " \n " );
1248- if (isDotProduct && hasReduceOp) {
1291+ if (isDotProduct && hasReduceOp)
12491292 return failure ();
1250- }
1251- if (!isDotProduct && !hasReduceOp) {
1293+ if (!isDotProduct && !hasReduceOp)
12521294 return failure ();
1253- }
12541295 } else {
12551296 LLVM_DEBUG (llvm::dbgs () << " first matmul not found\n " );
12561297 }
0 commit comments