@@ -59,13 +59,6 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
5959 struct IntelDPASCapability dpasCap,
6060 const ArrayRef<int64_t > shape,
6161 unsigned numWarps) {
62- auto rank = shape.size ();
63- // Early exit for batched matmul
64- // TODO: current strategy is same as upstream, there could be better strategy
65- // when batch < numWarps
66- if (rank == 3 )
67- return {numWarps, 1 , 1 };
68-
6962 auto filter = [&dotOp](Operation *op) {
7063 return op->getParentRegion () == dotOp->getParentRegion ();
7164 };
@@ -76,29 +69,39 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
7669 if (isa<DotOp>(op) && (op != dotOp))
7770 return {numWarps, 1 };
7871
79- SmallVector<unsigned > ret{1 , 1 };
80- SmallVector<int64_t > shapePerWarp{dpasCap.repeatCount , dpasCap.executionSize };
72+ size_t rank = shape.size ();
73+ SmallVector<unsigned > ret (rank, 1 );
74+
75+ if (rank == 3 ) {
76+ int batchWarp = numWarps;
77+ while (batchWarp > shape[0 ])
78+ batchWarp /= 2 ;
79+ ret[0 ] = batchWarp;
80+ numWarps /= batchWarp;
81+ }
8182
8283 // Try to find a proper tiling shape for the dot operation.
8384 // It doubles the warp number in col or row in each time based on column to
8485 // width ratio.
8586 // By this, we can minimize the duplication of the dot operands A and B.
87+ SmallVector<int64_t > shapePerWarp{dpasCap.repeatCount , dpasCap.executionSize };
8688 uint32_t rowColRatio =
8789 ceil<uint32_t >(dpasCap.repeatCount , dpasCap.executionSize );
8890 uint32_t colRowRatio =
8991 ceil<uint32_t >(dpasCap.executionSize , dpasCap.repeatCount );
9092
93+ int rowDim = rank - 2 , colDim = rank - 1 ;
9194 do {
92- if (ret[0 ] * ret[1 ] >= numWarps)
95+ if (ret[rowDim ] * ret[colDim ] >= numWarps)
9396 break ;
94- if (shape[0 ] / (shapePerWarp[0 ] * colRowRatio) / ret[0 ] >=
95- shape[1 ] / (shapePerWarp[1 ] * rowColRatio) / ret[1 ]) {
96- if (ret[0 ] < shape[0 ] / shapePerWarp[0 ])
97- ret[0 ] *= 2 ;
97+ if (shape[rowDim ] / (shapePerWarp[0 ] * colRowRatio) / ret[rowDim ] >=
98+ shape[colDim ] / (shapePerWarp[1 ] * rowColRatio) / ret[colDim ]) {
99+ if (ret[rowDim ] < shape[rowDim ] / shapePerWarp[0 ])
100+ ret[rowDim ] *= 2 ;
98101 else
99- ret[1 ] *= 2 ;
102+ ret[colDim ] *= 2 ;
100103 } else {
101- ret[1 ] *= 2 ;
104+ ret[colDim ] *= 2 ;
102105 }
103106 } while (true );
104107
0 commit comments