@@ -670,16 +670,16 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
670670      case  4 :
671671        op.setScalesIdxB (val);
672672        break ;
673-       default :  
673+       default :
674674        break ;
675675      }
676676    };
677677
678678    //  Obtain flat index from offsets and shape.
679679    auto  getIdxFromExtract = [](vector::ExtractOp op) {
680680      ShapedType ty = dyn_cast<ShapedType>(op.getOperand (0 ).getType ());
681-       int  cumul = 1 ;
682-       int  idx = 0 ;
681+       int64_t  cumul = 1 ;
682+       int64_t  idx = 0 ;
683683      for  (auto  [offset, size] :
684684           reverse (llvm::zip_equal (op.getStaticPosition (), ty.getShape ()))) {
685685        idx += offset * cumul;
@@ -720,33 +720,37 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
720720    for  (auto  opIdx : SmallVector<int64_t >({3 , 4 })) {
721721      auto  insertOp = op.getOperand (opIdx).getDefiningOp <vector::InsertOp>();
722722      if  (!insertOp) {
723-         return  failure ();
723+         return  rewriter.notifyMatchFailure (op,
724+                                            " defining op not a vector.insert" 
724725      }
725726      if  (llvm::any_of (insertOp.getResult ().getUses (), checkIfUnpackable)) {
726-         return  failure ();
727+         return  rewriter.notifyMatchFailure (op,
728+                                            " some scaled mfma's already packed" 
727729      }
728730
729731      auto  extractOp =
730732          insertOp.getOperand (0 ).getDefiningOp <vector::ExtractOp>();
731733      if  (!extractOp) {
732-         return  failure ();
734+         return  rewriter.notifyMatchFailure (op,
735+                                            " defining op not a vector.extract" 
733736      }
734737
735738      Value scaleSrc = extractOp.getOperand (0 );
736-       auto  stype = dyn_cast<ShapedType >(scaleSrc.getType ());
739+       auto  stype = dyn_cast<VectorType >(scaleSrc.getType ());
737740      if  (!stype) {
738-         return  failure ( );
741+         return  rewriter. notifyMatchFailure (op,  " not a shaped type " 
739742      }
740743      //  Dynamic dims are not handled yet; assume the input has been padded to
741744      //  a static shape.
742-       if  (llvm::any_of (llvm::seq< int64_t >( 0 ,  stype.getRank ()), 
743-                        [&]( int64_t  i) {  return  stype. isDynamicDim (i); })) { 
744-         return   failure ( );
745+       if  (! stype.hasStaticShape ()) { 
746+         return  rewriter. notifyMatchFailure (op, 
747+                                             " dynamic dims not yet supported " 
745748      }
746749
747750      int64_t  numElements = stype.getNumElements ();
748-       if  (numElements <= 4 ) {
749-         return  failure ();
751+       if  (numElements <= 4  || !(numElements % 4 )) {
752+         return  rewriter.notifyMatchFailure (
753+             op, " no packing if # of scales less than or indivisible by four" 
750754      }
751755
752756      Type newSrcType = VectorType::get (
@@ -760,7 +764,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
760764          loc, newScaleSrc, SmallVector<int64_t >{offsets[0 ], 0 },
761765          SmallVector<int64_t >{1 , 4 }, SmallVector<int64_t >{1 , 1 });
762766      Value scale = rewriter.create <vector::ShapeCastOp>(loc, scaleTy, extract);
763-       op.setOperand (opIdx, scale);
767+       rewriter.modifyOpInPlace (
768+           op, [&op, opIdx, scale] { op->setOperand (opIdx, scale); });
764769      setOpsel (opIdx, offsets[1 ]);
765770    }
766771    return  success ();
0 commit comments