@@ -37,8 +37,6 @@ void XeGPUDialect::initialize() {
3737 >();
3838}
3939
40- #define DEBUG_TYPE " xegpu"
41-
4240/// Generates instructions to compute offsets for a subgroup identified by
4341/// its multidimensional indices (sgId), using the specified subgroup layout
4442/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -788,56 +786,21 @@ SmallVector<int64_t> MemDescType::getStrides() {
788786 strides.push_back (cast<IntegerAttr>(attr).getInt ());
789787 }
790788
791- llvm::dbgs () << " DEBUG: matrixShape = [" ;
792- for (size_t i = 0 ; i < matrixShape.size (); ++i) {
793- llvm::dbgs () << matrixShape[i];
794- if (i < matrixShape.size () - 1 )
795- llvm::dbgs () << " , " ;
796- }
797- llvm::dbgs () << " ]\n " ;
798-
799- llvm::dbgs () << " DEBUG: strides = [" ;
800- for (size_t i = 0 ; i < strides.size (); ++i) {
801- llvm::dbgs () << strides[i];
802- if (i < strides.size () - 1 )
803- llvm::dbgs () << " , " ;
804- }
805- llvm::dbgs () << " ]\n " ;
806-
807789 SmallVector<int64_t > innerBlkShape = getBlockSize ();
808- llvm::dbgs () << " DEBUG: innerBlkShape = [" ;
809- for (size_t i = 0 ; i < innerBlkShape.size (); ++i) {
810- llvm::dbgs () << innerBlkShape[i];
811- if (i < innerBlkShape.size () - 1 )
812- llvm::dbgs () << " , " ;
813- }
814- llvm::dbgs () << " ]\n " ;
815790
816791 // get perm from FCD to LCD
817792 // perm[i] = the dim with i-th smallest stride
818793 SmallVector<int , 4 > perm =
819794 llvm::to_vector<4 >(llvm::seq<int >(0 , strides.size ()));
820795 llvm::sort (perm, [&](int a, int b) { return strides[a] < strides[b]; });
821796
822- llvm::dbgs () << " DEBUG: perm = [" ;
823- for (size_t i = 0 ; i < perm.size (); ++i) {
824- llvm::dbgs () << perm[i];
825- if (i < perm.size () - 1 )
826- llvm::dbgs () << " , " ;
827- }
828- llvm::dbgs () << " ]\n " ;
829-
830797 assert (strides[perm[0 ]] == 1 && " inner most dim must have stride 1" );
831798
832- SmallVector<int64_t > innerBlkStride = computeStrides (innerBlkShape);
833-
834- llvm::dbgs () << " DEBUG: innerBlkStride = [" ;
835- for (size_t i = 0 ; i < innerBlkStride.size (); ++i) {
836- llvm::dbgs () << innerBlkStride[i];
837- if (i < innerBlkStride.size () - 1 )
838- llvm::dbgs () << " , " ;
839- }
840- llvm::dbgs () << " ]\n " ;
799+ SmallVector<int64_t > innerBlkStride (innerBlkShape.size ());
800+ innerBlkStride[perm[0 ]] = 1 ;
801+ for (size_t i = 1 ; i < perm.size (); ++i)
802+ innerBlkStride[perm[i]] =
803+ innerBlkStride[perm[i - 1 ]] * innerBlkShape[perm[i - 1 ]];
841804
842805 // compute the original matrix shape using the stride info
843806 // and compute the number of blocks in each dimension
@@ -850,56 +813,22 @@ SmallVector<int64_t> MemDescType::getStrides() {
850813 BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
851814 }
852815
853- llvm::dbgs () << " DEBUG: matrixShapeOrig = [" ;
854- for (size_t i = 0 ; i < matrixShapeOrig.size (); ++i) {
855- llvm::dbgs () << matrixShapeOrig[i];
856- if (i < matrixShapeOrig.size () - 1 )
857- llvm::dbgs () << " , " ;
858- }
859- llvm::dbgs () << " ]\n " ;
860-
861- llvm::dbgs () << " DEBUG: BlkShapeOrig = [" ;
862- for (size_t i = 0 ; i < BlkShapeOrig.size (); ++i) {
863- llvm::dbgs () << BlkShapeOrig[i];
864- if (i < BlkShapeOrig.size () - 1 )
865- llvm::dbgs () << " , " ;
866- }
867- llvm::dbgs () << " ]\n " ;
868-
869816 int64_t innerBlkSize = 1 ;
870817 for (auto s : innerBlkShape)
871818 innerBlkSize *= s;
872819
873- llvm::dbgs () << " DEBUG: innerBlkSize = " << innerBlkSize << " \n " ;
874-
875820 SmallVector<int64_t > outerBlkStride (matrixShape.size ());
876821 outerBlkStride[perm[0 ]] = innerBlkSize;
877822 for (size_t i = 0 ; i < perm.size () - 1 ; ++i) {
878823 outerBlkStride[perm[i + 1 ]] =
879824 outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
880825 }
881826
882- llvm::dbgs () << " DEBUG: outerBlkStride = [" ;
883- for (size_t i = 0 ; i < outerBlkStride.size (); ++i) {
884- llvm::dbgs () << outerBlkStride[i];
885- if (i < outerBlkStride.size () - 1 )
886- llvm::dbgs () << " , " ;
887- }
888- llvm::dbgs () << " ]\n " ;
889-
890827 // combine the inner and outer strides
891828 SmallVector<int64_t > blockedStrides;
892829 blockedStrides.append (outerBlkStride.begin (), outerBlkStride.end ());
893830 blockedStrides.append (innerBlkStride.begin (), innerBlkStride.end ());
894831
895- llvm::dbgs () << " DEBUG: blockedStrides = [" ;
896- for (size_t i = 0 ; i < blockedStrides.size (); ++i) {
897- llvm::dbgs () << blockedStrides[i];
898- if (i < blockedStrides.size () - 1 )
899- llvm::dbgs () << " , " ;
900- }
901- llvm::dbgs () << " ]\n " ;
902-
903832 return blockedStrides;
904833}
905834
@@ -911,12 +840,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
911840 SmallVector<int64_t > blockShape = getBlockSize ();
912841 SmallVector<int64_t > strides = getStrides ();
913842
914- LLVM_DEBUG (llvm::dbgs () << " getLinearOffsets: blockShape=[" ;
915- llvm::interleaveComma (blockShape, llvm::dbgs ());
916- llvm::dbgs () << " ], strides=[" ;
917- llvm::interleaveComma (strides, llvm::dbgs ());
918- llvm::dbgs () << " ]\n " );
919-
920843 // blockshape equal to matrixshape means no blocking
921844 if (llvm::equal (blockShape, matrixShape)) {
922845 // remove the outer dims from strides
@@ -937,8 +860,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
937860 blockedOffsets.append (rems.begin (), rems.end ());
938861
939862 offsets = blockedOffsets;
940- LLVM_DEBUG (llvm::dbgs () << " getLinearOffsets: blocked offsets size="
941- << offsets.size () << " \n " );
942863 }
943864
944865 // Start with initial value as matrix descriptor's base offset.
@@ -949,9 +870,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
949870 linearOffset = arith::AddIOp::create (builder, loc, mulVal, linearOffset);
950871 }
951872
952- LLVM_DEBUG (llvm::dbgs () << " getLinearOffsets: final linearOffset="
953- << linearOffset << " \n " );
954-
955873 return linearOffset;
956874}
957875
0 commit comments