@@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op,
213213 return success ();
214214}
215215
216+ static bool isTypeSupported (Type outType, Type operandAType,
217+ Type operandBType) {
218+ if (!outType.isF32 () && !outType.isSignedInteger (32 ))
219+ return false ;
220+
221+ if (outType.isF32 ()) {
222+ if (!(operandAType.isF32 () && operandBType.isF32 ()) &&
223+ !(operandAType.isBF16 () && operandBType.isBF16 ()))
224+ return false ;
225+ }
226+ if (outType.isSignedInteger (32 )) {
227+ if (!(operandAType.isSignedInteger (8 ) ||
228+ operandAType.isUnsignedInteger (8 )) &&
229+ (operandBType.isSignedInteger (8 ) || operandBType.isUnsignedInteger (8 )))
230+ return false ;
231+ }
232+ return true ;
233+ }
234+
235+ // TODO(haixin): could use compiler-wide VNNI utils?
236+ static bool isInVnniLayout (ShapedType type) {
237+ if (!type.getElementType ().isBF16 () &&
238+ !type.getElementType ().isSignedInteger (8 ) &&
239+ !type.getElementType ().isUnsignedInteger (8 ))
240+ return false ;
241+
242+ auto blockingFactor = 0 ;
243+ if (type.getElementType ().isBF16 ())
244+ blockingFactor = 2 ;
245+ else if (type.getElementType ().isSignedInteger (8 ) ||
246+ type.getElementType ().isUnsignedInteger (8 ))
247+ blockingFactor = 4 ;
248+
249+ return type.getShape ().back () == blockingFactor;
250+ }
251+
216252// ///////////////////////////////////////////////////
217253// Start of BrgemmOp
218254
@@ -308,9 +344,8 @@ static inline ArrayRef<int64_t> getShapedValueShape(Value val) {
308344 assert ((llvm::isa<TensorType>(val.getType ()) ||
309345 llvm::isa<MemRefType>(val.getType ())) &&
310346 " Expecting shaped value" );
311- if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType ())) {
347+ if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType ()))
312348 return tensorTy.getShape ();
313- }
314349 auto memrefTy = dyn_cast_or_null<MemRefType>(val.getType ());
315350 return memrefTy.getShape ();
316351}
@@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() {
331366 return op.emitOpError ()
332367 << " expect inputs and its related info to be size 2\n " ;
333368
369+ auto elemTypeA = getElementTypeOrSelf (ins[0 ]);
370+ auto elemTypeB = getElementTypeOrSelf (ins[1 ]);
371+ auto elemTypeC = getElementTypeOrSelf (out);
372+ if (!isTypeSupported (elemTypeC, elemTypeA, elemTypeB))
373+ return op.emitOpError () << " unsupported input matrix types\n " ;
374+
334375 ArrayRef<int64_t > dimA = getShapedValueShape (ins[0 ]);
335376 ArrayRef<int64_t > dimB = getShapedValueShape (ins[1 ]);
336377 ArrayRef<int64_t > dimC = getShapedValueShape (out);
337378 if (dimA.size () != 3 )
338379 return op.emitOpError () << " expect input A to be 3D\n " ;
339- if (dimB.size () != 3 && dimB.size () != 4 )
340- return op.emitOpError () << " expect input B to be 3D or 4D\n " ;
341- if (dimB.size () == 4 && (dimB[3 ] != 2 && dimB[3 ] != 4 ))
342- return op.emitOpError () << " expect input B vnni step to be 2 or 4\n " ;
380+ if (!elemTypeB.isF32 ()) {
381+ if (dimB.size () != 4 ||
382+ !isInVnniLayout (dyn_cast<ShapedType>(ins[1 ].getType ())))
383+ return op.emitOpError ()
384+ << " expect a 4d VNNI input B for non-F32 operand: " << ins[1 ];
385+ } else {
386+ if (dimB.size () != 3 )
387+ return op.emitOpError ()
388+ << " expect a 3d input B for F32 operand: " << ins[1 ];
389+ }
343390 if (dimC.size () != 2 )
344391 return op.emitOpError () << " expect input C to be 2D\n " ;
345392 for (auto dim : batchDims)
@@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() {
558605// ///////////////////////////////////////////////////
559606// Start of BrgemmExecuteOp
560607
561- // TODO(haixin): could use compiler-wide VNNI utils?
562- static bool isInVnniLayout (MemRefType memref) {
563- if (!memref.getElementType ().isBF16 () &&
564- !memref.getElementType ().isSignedInteger (8 ) &&
565- !memref.getElementType ().isUnsignedInteger (8 ))
566- return false ;
567-
568- auto blockingFactor = 0 ;
569- if (memref.getElementType ().isBF16 ())
570- blockingFactor = 2 ;
571- else if (memref.getElementType ().isSignedInteger (8 ) ||
572- memref.getElementType ().isUnsignedInteger (8 ))
573- blockingFactor = 4 ;
574-
575- return memref.getShape ().back () == blockingFactor;
576- }
577-
578- static bool isTypeSupported (Type outType, Type operandAType,
579- Type operandBType) {
580- if (!outType.isF32 () && !outType.isSignedInteger (32 ))
581- return false ;
582-
583- if (outType.isF32 ()) {
584- if (!(operandAType.isF32 () && operandBType.isF32 ()) &&
585- !(operandAType.isBF16 () && operandBType.isBF16 ()))
586- return false ;
587- }
588- if (outType.isSignedInteger (32 )) {
589- if (!(operandAType.isSignedInteger (8 ) ||
590- operandAType.isUnsignedInteger (8 )) &&
591- (operandBType.isSignedInteger (8 ) || operandBType.isUnsignedInteger (8 )))
592- return false ;
593- }
594- return true ;
595- }
596-
597608LogicalResult BrgemmExecuteOp::verify () {
598609 BrgemmExecuteOp &brgemmOp = *this ;
599610
0 commit comments