@@ -3474,7 +3474,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34743474// / It checks if the first result dimension is a function of the first
34753475// / dimension.
34763476static bool isValidBatchDim (AffineMap bcastMap) {
3477- assert (bcastMap.getNumResults () == 3 && " Expected three result dim expr." );
34783477 AffineExpr exp = bcastMap.getResult (0 );
34793478 return exp.isFunctionOfDim (0 );
34803479}
@@ -3490,6 +3489,48 @@ static bool isValidOutputResultDim(AffineMap outputMap) {
34903489 exp2.isFunctionOfDim (nPos);
34913490}
34923491
3492+ // Check general validity of input indexing map.
3493+ static LogicalResult verifyInputMaps (BatchMatmulOp batchMatmulOp,
3494+ AffineMap opIndexingMap,
3495+ AffineMap defaultIndexingMap, bool isLHS) {
3496+ // Check the result dims are valid.
3497+ if (!isValidResultDimExprs (opIndexingMap, defaultIndexingMap))
3498+ return batchMatmulOp->emitOpError ()
3499+ << " Unexpected dim expression in map result." ;
3500+
3501+ // Check for valid number of result dims of input maps.
3502+ if (opIndexingMap.getNumResults () > 3 )
3503+ return batchMatmulOp->emitOpError ()
3504+ << " no. of result dim expression cannot exceed 3." ;
3505+
3506+ // Check if the requested broadcast is valid.
3507+ if (isBroadcasted (opIndexingMap, defaultIndexingMap)) {
3508+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap (opIndexingMap, isLHS))
3509+ return batchMatmulOp->emitOpError () << " Invalid broadcast requested." ;
3510+ } else if (!isValidBatchDim (opIndexingMap)) {
3511+ return batchMatmulOp->emitOpError ()
3512+ << " Invalid batch dimension expression." ;
3513+ }
3514+ return success ();
3515+ }
3516+
3517+ // / This function checks if the given AffineMap for the output of a
3518+ // / BatchMatmulOp has exactly 3 result dimensions and if the output map result
3519+ // / dimensions are valid.
3520+ static LogicalResult verifyOutputMap (BatchMatmulOp batchMatmulOp,
3521+ AffineMap opIndexingMap) {
3522+ if (opIndexingMap.getNumResults () != 3 )
3523+ return batchMatmulOp->emitOpError ()
3524+ << " expects 3 dims, but got (" << opIndexingMap.getNumResults ()
3525+ << " )." ;
3526+
3527+ if (!isValidOutputResultDim (opIndexingMap))
3528+ return batchMatmulOp->emitOpError ()
3529+ << " Invalid output map result dimension." ;
3530+
3531+ return success ();
3532+ }
3533+
34933534// / Verifies the broadcast and transpose semantic specified by the explicit
34943535// / indexing map for the BatchMatmulOp \p op for each operand specified by \p
34953536// / opIndex.
@@ -3503,37 +3544,14 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
35033544
35043545 auto opIndexingMap = opIndexingMaps[opIndex];
35053546 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3506- // Check general validity of indexing map results.
3507- if (opIndex < 2 ) {
3508- if (!isValidResultDimExprs (opIndexingMap, defaultIndexingMap))
3509- return batchMatmulOp->emitOpError ()
3510- << " Unexpected dim expression in map result." ;
3511- // Check if the requested broadcast is valid.
3512- if (isBroadcasted (opIndexingMap, defaultIndexingMap)) {
3513- if (!batchMatmulOp.isValidLhsRhsBroadcastMap (opIndexingMap,
3514- opIndex == 0 )) {
3515- return batchMatmulOp->emitOpError () << " Invalid broadcast requested." ;
3516- }
3517- } else {
3518- // Check for valid number of result dims of input maps.
3519- if (opIndexingMap.getNumResults () != 3 )
3520- return batchMatmulOp->emitOpError ()
3521- << " no. of result dim expression cannot exceed 3." ;
3522-
3523- if (!isValidBatchDim (opIndexingMap))
3524- return batchMatmulOp->emitOpError ()
3525- << " Invalid batch dimension expression." ;
3526- }
3527- } else {
3528- // Check for valid number of result dims of output map.
3529- if (opIndexingMap.getNumResults () != 3 )
3530- return batchMatmulOp->emitOpError ()
3531- << " no. of result dim expression cannot exceed 3." ;
35323547
3533- if (!isValidOutputResultDim (opIndexingMap))
3534- return batchMatmulOp->emitOpError ()
3535- << " Invalid output map result dimension." ;
3536- }
3548+ if (opIndex == 2 && failed (verifyOutputMap (batchMatmulOp, opIndexingMap)))
3549+ return failure ();
3550+
3551+ if (failed (verifyInputMaps (batchMatmulOp, opIndexingMap, defaultIndexingMap,
3552+ opIndex == 0 )))
3553+ return failure ();
3554+
35373555 return success ();
35383556}
35393557
0 commit comments