@@ -3426,11 +3426,10 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
34263426 return arith::ConstantOp::materialize (builder, value, type, loc);
34273427}
34283428
3429- // / Returns true if the result AffineExpr of the \p explicitMap is same as \p
3430- // / defaultMap.
3431- static bool isValidResultDimExprs (AffineMap explictMap, AffineMap defaultMap) {
3432- auto explicitRange = explictMap.getResults ();
3433- auto defaultRange = defaultMap.getResults ();
3429+ // Returns true if the result expression of `subMap` are a subset of `fullMap`.
3430+ static bool areResultExprsSubsetOf (AffineMap subMap, AffineMap fullMap) {
3431+ auto explicitRange = subMap.getResults ();
3432+ auto defaultRange = fullMap.getResults ();
34343433 DenseSet<AffineExpr> explicitSet (explicitRange.begin (), explicitRange.end ());
34353434 DenseSet<AffineExpr> defaultSet (defaultRange.begin (), defaultRange.end ());
34363435 llvm::set_union (explicitSet, defaultSet);
@@ -3455,7 +3454,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34553454 auto opIndexingMap = opIndexingMaps[opIndex];
34563455 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
34573456 // Check general validity of indexing map results.
3458- if (!isValidResultDimExprs (opIndexingMap, defaultIndexingMap))
3457+ if (!areResultExprsSubsetOf (opIndexingMap, defaultIndexingMap))
34593458 return matmulOp->emitOpError ()
34603459 << " Unexpected dim expression in map result." ;
34613460
@@ -3470,44 +3469,31 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34703469 return success ();
34713470}
34723471
3473- // / Checks if the given AffineMap represents a valid batch dimension.
3474- // / It checks if the first result dimension is a function of the first
3475- // / dimension.
3476- static bool isValidBatchDim (AffineMap bcastMap) {
3477- AffineExpr exp = bcastMap.getResult (0 );
3478- return exp.isFunctionOfDim (0 );
3479- }
3480-
3481- // / Checks if the given AffineMap's result dimensions are valid output result
3482- // / dimensions.
3483- static bool isValidOutputResultDim (AffineMap outputMap) {
3484- enum Indices { batchPos, mPos , nPos };
3485- AffineExpr exp0 = outputMap.getResult (batchPos);
3486- AffineExpr exp1 = outputMap.getResult (mPos );
3487- AffineExpr exp2 = outputMap.getResult (nPos);
3488- return exp0.isFunctionOfDim (batchPos) && exp1.isFunctionOfDim (mPos ) &&
3489- exp2.isFunctionOfDim (nPos);
3490- }
3491-
34923472// Check general validity of input indexing map.
34933473static LogicalResult verifyInputMaps (BatchMatmulOp batchMatmulOp,
34943474 AffineMap opIndexingMap,
34953475 AffineMap defaultIndexingMap, bool isLHS) {
34963476 // Check the result dims are valid.
3497- if (!isValidResultDimExprs (opIndexingMap, defaultIndexingMap))
3477+ if (!areResultExprsSubsetOf (opIndexingMap, defaultIndexingMap))
34983478 return batchMatmulOp->emitOpError ()
3499- << " Unexpected dim expression in map result." ;
3479+ << " Unexpected result dim expression (outside the set of default "
3480+ " result dims)." ;
35003481
35013482 // Check for valid number of result dims of input maps.
35023483 if (opIndexingMap.getNumResults () > 3 )
35033484 return batchMatmulOp->emitOpError ()
3504- << " no. of result dim expression cannot exceed 3." ;
3485+ << " no. of result dim expressions exceeds 3." ;
3486+
3487+ auto hasValidBatchDim = [](AffineMap map) {
3488+ AffineExpr batchDim = map.getResult (0 );
3489+ return batchDim.isFunctionOfDim (0 );
3490+ };
35053491
35063492 // Check if the requested broadcast is valid.
35073493 if (isBroadcasted (opIndexingMap, defaultIndexingMap)) {
35083494 if (!batchMatmulOp.isValidLhsRhsBroadcastMap (opIndexingMap, isLHS))
35093495 return batchMatmulOp->emitOpError () << " Invalid broadcast requested." ;
3510- } else if (!isValidBatchDim (opIndexingMap)) {
3496+ } else if (!hasValidBatchDim (opIndexingMap)) {
35113497 return batchMatmulOp->emitOpError ()
35123498 << " Invalid batch dimension expression." ;
35133499 }
@@ -3524,7 +3510,13 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
35243510 << " expects 3 dims, but got (" << opIndexingMap.getNumResults ()
35253511 << " )." ;
35263512
3527- if (!isValidOutputResultDim (opIndexingMap))
3513+ auto areValidOutputResultDim = [](AffineMap outputMap) {
3514+ return outputMap.getResult (0 ).isFunctionOfDim (0 ) &&
3515+ outputMap.getResult (1 ).isFunctionOfDim (1 ) &&
3516+ outputMap.getResult (2 ).isFunctionOfDim (2 );
3517+ };
3518+
3519+ if (!areValidOutputResultDim (opIndexingMap))
35283520 return batchMatmulOp->emitOpError ()
35293521 << " Invalid output map result dimension." ;
35303522
@@ -3941,7 +3933,8 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
39413933
39423934// / Returns true if the given broadcast map \p bcastMap is valid for this op.
39433935bool BatchMatmulOp::isValidLhsRhsBroadcastMap (AffineMap bcastMap, bool isLHS) {
3944- assert (bcastMap.getNumResults () < 3 && " Expected single result dim expr." );
3936+ assert (bcastMap.getNumResults () < 3 &&
3937+ " Expected less than 3 result dim expr." );
39453938 bool isValid = false ;
39463939 enum Indices { batchPos, mPos , nPos, kPos };
39473940 if (bcastMap.getNumResults () == 1 ) {
0 commit comments