@@ -2501,7 +2501,7 @@ static AffineMap getPartialResultMap(AffineMap map, AttentionOpDetail &opInfo) {
25012501FailureOr<SmallVector<Value>>
25022502OnlineAttentionOp::generateInitialTensorForPartialReduction (
25032503 OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
2504- ArrayRef< int > reductionDim ) {
2504+ const llvm::SetVector< unsigned > &reductionDims ) {
25052505 FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get (
25062506 getQueryMap (), getKeyMap (), getValueMap (), getOutputMap ());
25072507 if (failed (maybeOpInfo)) {
@@ -2555,8 +2555,10 @@ OnlineAttentionOp::generateInitialTensorForPartialReduction(
25552555}
25562556
25572557FailureOr<TilingResult> OnlineAttentionOp::tileToPartialReduction (
2558- OpBuilder &b, Location loc, ValueRange init, ArrayRef<OpFoldResult> offsets,
2559- ArrayRef<OpFoldResult> sizes, ArrayRef<int > reductionDims) {
2558+ OpBuilder &b, Location loc, ReductionTilingStrategy strategy,
2559+ ValueRange init, ArrayRef<OpFoldResult> offsets,
2560+ ArrayRef<OpFoldResult> sizes,
2561+ const llvm::SetVector<unsigned > &reductionDims) {
25602562 FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get (
25612563 getQueryMap (), getKeyMap (), getValueMap (), getOutputMap ());
25622564 if (failed (maybeOpInfo)) {
@@ -2707,10 +2709,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
27072709 return genericOp.getResult (0 );
27082710}
27092711
2710- FailureOr<MergeResult>
2711- OnlineAttentionOp::mergeReductions (OpBuilder &b, Location loc,
2712- ValueRange partialReduce,
2713- ArrayRef<int > reductionDim) {
2712+ FailureOr<MergeResult> OnlineAttentionOp::mergeReductions (
2713+ OpBuilder &b, Location loc, ValueRange partialReduce,
2714+ const llvm::SetVector<unsigned > &reductionDims) {
27142715 FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get (
27152716 getQueryMap (), getKeyMap (), getValueMap (), getOutputMap ());
27162717 if (failed (maybeOpInfo)) {
@@ -2753,8 +2754,10 @@ OnlineAttentionOp::mergeReductions(OpBuilder &b, Location loc,
27532754
27542755LogicalResult OnlineAttentionOp::getPartialResultTilePosition (
27552756 OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2756- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2757- SmallVector<OpFoldResult> &resultSizes, ArrayRef<int > reductionDims) {
2757+ ArrayRef<OpFoldResult> sizes,
2758+ const llvm::SetVector<unsigned > &reductionDims,
2759+ SmallVector<OpFoldResult> &resultOffsets,
2760+ SmallVector<OpFoldResult> &resultSizes) {
27582761
27592762 FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get (
27602763 getQueryMap (), getKeyMap (), getValueMap (), getOutputMap ());
0 commit comments