@@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
34363436 return explicitSet == defaultSet;
34373437}
34383438
3439- // / Returns true if the \p explictMap is broadcasted with respect to the
3440- // / \p defaultMap.
3439+ // / Check if the user defined map is valid broadcast map. Here broadcast
3440+ // / indexing maps are defined in context of corresponding default indexing maps
3441+ // / for the given Op. This way the check becomes very simple i.e just check the
3442+ // / number of result dims.
3443+ // / Returns true if the explictMap is broadcasted with respect to the
3444+ // / defaultMap.
34413445static bool isBroadcasted (AffineMap explictMap, AffineMap defaultMap) {
34423446 return explictMap.getNumResults () < defaultMap.getNumResults ();
34433447}
@@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34583462 return matmulOp->emitOpError ()
34593463 << " Unexpected dim expression in map result." ;
34603464
3461- // Check if the user defined map is valid broadcast map. Here broadcast
3462- // indexing maps are defined in context of corresponding default indexing maps
3463- // for the given Op. This way the check becomes very simple i.e just check the
3464- // number of result dims.
34653465 if (isBroadcasted (opIndexingMap, defaultIndexingMap)) {
34663466 if (!matmulOp.isValidLhsRhsBroadcastMap (opIndexingMap)) {
34673467 return matmulOp->emitOpError ()
@@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
35273527}
35283528
35293529// / Verifies the broadcast and transpose semantic specified by the explicit
3530- // / indexing map for the BatchMatmulOp \p op for each operand specified by \p
3531- // / opIndex.
3530+ // / indexing map for the BatchMatmulOp op for each operand specified by opIndex.
35323531static LogicalResult
35333532verifyExtendedBatchMatmulSemantic (BatchMatmulOp batchMatmulOp,
35343533 unsigned opIndex) {
@@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
39343933 return defaultMaps != explicitMaps;
39353934}
39363935
3937- // / Returns true if the given broadcast map \p bcastMap is valid for this op.
3936+ // / Returns true if the given broadcast map bcastMap is valid for this op.
39383937bool BatchMatmulOp::isValidLhsRhsBroadcastMap (AffineMap bcastMap, bool isLHS) {
39393938 assert (bcastMap.getNumResults () < 3 &&
39403939 " Expected less than 3 result dim expr." );
@@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
39603959 RegionBuilderHelper helper (b, block);
39613960 SmallVector<Value> yields;
39623961
3963- Value value1 =
3964- helper.buildTypeFn (TypeFn::cast_signed, block.getArgument (2 ).getType (),
3965- block.getArgument (0 ));
3966- Value value2 =
3967- helper.buildTypeFn (TypeFn::cast_signed, block.getArgument (2 ).getType (),
3968- block.getArgument (1 ));
3969- Value value3 = helper.buildBinaryFn (BinaryFn::mul, value1, value2);
3970- Value value4 =
3971- helper.buildBinaryFn (BinaryFn::add, block.getArgument (2 ), value3);
3972- yields.push_back (value4);
3962+ auto toType = block.getArgument (2 ).getType ();
3963+ Value castValA =
3964+ helper.buildTypeFn (TypeFn::cast_signed, toType, block.getArgument (0 ));
3965+ Value castValB =
3966+ helper.buildTypeFn (TypeFn::cast_signed, toType, block.getArgument (1 ));
3967+ Value mulVal = helper.buildBinaryFn (BinaryFn::mul, castValA, castValB);
3968+ Value addVal =
3969+ helper.buildBinaryFn (BinaryFn::add, block.getArgument (2 ), mulVal);
3970+ yields.push_back (addVal);
39733971 helper.yieldOutputs (yields);
39743972}
39753973
0 commit comments