Skip to content

Commit 73c7462

Browse files
authored
[VectorExt] Add generic vectorization to iree_vector_ext.transfer_gather (#20476)
This patch adds support in GenericVectorization for vectorizing certain gather-like linalg.generic operations to iree_vector_ext.transfer_gather. The support is gated behind a pass option (vectorize-to-transfer-gather, off by default) and should not affect any existing pipelines.
1 parent 68a2085 commit 73c7462

File tree

9 files changed

+594
-7
lines changed

9 files changed

+594
-7
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ iree_compiler_cc_library(
191191
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
192192
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
193193
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
194+
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms",
194195
"//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
195196
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
196197
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ iree_cc_library(
226226
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
227227
iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
228228
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
229+
iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms
229230
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
230231
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
231232
iree::compiler::Codegen::Interfaces::UKernelOpInterface

compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/Codegen/Common/Passes.h"
88
#include "iree/compiler/Codegen/Common/TileSizeSelection.h"
9+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
10+
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h"
911
#include "iree/compiler/Codegen/Utils/Utils.h"
1012
#include "mlir/Dialect/Affine/LoopUtils.h"
1113
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
@@ -99,8 +101,9 @@ class GenericVectorizationPass final
99101
GenericVectorizationPass>::GenericVectorizationPassBase;
100102

101103
void getDependentDialects(DialectRegistry &registry) const override {
102-
registry.insert<tensor::TensorDialect, linalg::LinalgDialect,
103-
vector::VectorDialect>();
104+
registry
105+
.insert<tensor::TensorDialect, linalg::LinalgDialect,
106+
vector::VectorDialect, IREE::VectorExt::IREEVectorExtDialect>();
104107
}
105108
void runOnOperation() override;
106109
};
@@ -156,8 +159,16 @@ void GenericVectorizationPass::runOnOperation() {
156159
}
157160
// Pad scalable dims with `false` to match the vector sizes.
158161
scalableVecDims.resize(vectorSizes.size());
159-
(void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
160-
vectorizeGatherAccesses);
162+
163+
// Try to vectorize to transfer_gather, if possible.
164+
if (isa<linalg::GenericOp>(op) && vectorizeToTransferGather) {
165+
(void)IREE::VectorExt::vectorizeGatherLikeGenericToTransferGather(
166+
rewriter, cast<linalg::GenericOp>(op), vectorSizes, scalableVecDims,
167+
vectorizeGatherAccesses);
168+
} else {
169+
(void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
170+
vectorizeGatherAccesses);
171+
}
161172
};
162173

163174
{

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ def GenericVectorizationPass :
352352
"Rewrite all tensor.pad ops in the function to vector form.">,
353353
Option<"vectorizeGatherAccesses", "vectorize-gather-accesses", "bool", /*default=*/"false",
354354
"Enable vectorizaiton of operations that may generate vector.gather operations.">,
355+
Option<"vectorizeToTransferGather", "vectorize-to-transfer-gather", "bool", /*default=*/"false",
356+
"Enables vectorization of gather-like operations that may generate iree_vector_ext.transfer_gather">,
355357
Option<"enableCleanup", "enable-cleanup", "bool",/*default=*/"true",
356358
"Enable cleanups after vectorization. The patterns touch the structure"
357359
"generated from tiling so it affects later steps like bufferization and vector hoisting.">,

compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization))" --split-input-file %s | FileCheck %s
22
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization{enable-vector-masking=true vectorize-padding=true}))" --split-input-file %s | FileCheck %s -check-prefix=CHECK-MASK
33
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization{fold-cast-into-contract=true}))" --split-input-file %s | FileCheck %s -check-prefix=CHECK-FOLD
4+
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-to-transfer-gather=true}))" --split-input-file %s | FileCheck %s --check-prefix=CHECK-GATHER
45

56
func.func @matmul(%lhs: tensor<3x4xf16>, %rhs: tensor<4x5xf16>, %acc: tensor<3x5xf32>) -> tensor<3x5xf32> {
67
%result = linalg.matmul ins(%lhs, %rhs: tensor<3x4xf16>, tensor<4x5xf16>) outs(%acc: tensor<3x5xf32>) -> tensor<3x5xf32>
@@ -533,3 +534,172 @@ func.func @depthwise_conv_fold_away_masking(%arg0: tensor<1x68x120x96xf32>, %arg
533534
// CHECK-MASK: vector.fma
534535
// CHECK-MASK-NOT: vector.create_mask
535536
// CHECK-MASK-NOT: vector.constant_mask
537+
538+
// -----
539+
540+
!storage = tensor<8192x8xf16>
541+
!ind = tensor<128xi64>
542+
!x = tensor<128x8xf16>
543+
544+
#gather = {
545+
indexing_maps = [affine_map<(page, vec) -> (page)>,
546+
affine_map<(page, vec) -> (page, vec)>],
547+
iterator_types = ["parallel", "parallel"]
548+
}
549+
550+
func.func @paged_gather_read(%storage : !storage, %ind: !ind) -> !x {
551+
%x = tensor.empty() : !x
552+
%x_g = linalg.generic #gather
553+
ins(%ind : !ind)
554+
outs(%x : !x) {
555+
^bb0(%page: i64, %out: f16):
556+
%pageidx = arith.index_cast %page : i64 to index
557+
%vec = linalg.index 1 : index
558+
%extracted = tensor.extract %storage[%pageidx, %vec] : !storage
559+
linalg.yield %extracted : f16
560+
} -> !x
561+
return %x_g : !x
562+
}
563+
564+
// CHECK-GATHER-LABEL: @paged_gather_read
565+
// CHECK-GATHER-SAME: %[[ARG0:.+]]: tensor<8192x8xf16>, %[[ARG1:.+]]: tensor<128xi64>
566+
// CHECK-GATHER: %[[INDEX_LOAD:.+]] = vector.transfer_read %[[ARG1]]
567+
// CHECK-GATHER: %[[INDEX_CAST:.+]] = arith.index_cast %[[INDEX_LOAD]] : vector<128xi64> to vector<128xindex>
568+
// CHECK-GATHER: %[[GATHER:.+]] = iree_vector_ext.transfer_gather %[[ARG0]]
569+
// CHECK-GATHER-SAME: [%[[INDEX_CAST]]: vector<128xindex>, None]
570+
// CHECK-GATHER: vector.transfer_write %[[GATHER]], %{{.*}}
571+
572+
// -----
573+
574+
!storage = tensor<8192x8xf16>
575+
!x = tensor<128x8xf16>
576+
577+
#gather = {
578+
indexing_maps = [affine_map<(page, vec) -> (page, vec)>],
579+
iterator_types = ["parallel", "parallel"]
580+
}
581+
582+
func.func @contiguous_gather_read(%storage : !storage) -> !x {
583+
%x = tensor.empty() : !x
584+
%x_g = linalg.generic #gather
585+
outs(%x : !x) {
586+
^bb0(%out: f16):
587+
%pageidx = linalg.index 0 : index
588+
%vec = linalg.index 1 : index
589+
%extracted = tensor.extract %storage[%pageidx, %vec] : !storage
590+
linalg.yield %extracted : f16
591+
} -> !x
592+
return %x_g : !x
593+
}
594+
595+
// CHECK-GATHER-LABEL: @contiguous_gather_read
596+
// CHECK-GATHER-SAME: %[[ARG0:.+]]: tensor<8192x8xf16>
597+
// CHECK-GATHER: %[[GATHER:.+]] = iree_vector_ext.transfer_gather %[[ARG0]]
598+
// CHECK-GATHER-SAME: [None, None]
599+
// CHECK-GATHER: vector.transfer_write %[[GATHER]], %{{.*}}
600+
601+
// -----
602+
603+
!storage = tensor<8192x8xf16>
604+
!ind = tensor<128xi64>
605+
!x = tensor<128x8xf16>
606+
607+
#gather = {
608+
indexing_maps = [affine_map<(page, vec) -> (page)>,
609+
affine_map<(page, vec) -> (page, vec)>],
610+
iterator_types = ["parallel", "parallel"]
611+
}
612+
613+
func.func @negative_strided_paged_gather_read(%storage : !storage, %ind: !ind) -> !x {
614+
%x = tensor.empty() : !x
615+
%c2 = arith.constant 2 : index
616+
%x_g = linalg.generic #gather
617+
ins(%ind : !ind)
618+
outs(%x : !x) {
619+
^bb0(%page: i64, %out: f16):
620+
%pageidx = arith.index_cast %page : i64 to index
621+
%vec = linalg.index 1 : index
622+
%strided_vec = arith.muli %vec, %c2 : index
623+
%extracted = tensor.extract %storage[%pageidx, %strided_vec] : !storage
624+
linalg.yield %extracted : f16
625+
} -> !x
626+
return %x_g : !x
627+
}
628+
629+
// For now, the vectorizer does not walk back through binary ops to find a mapping
630+
// from the iteration space to the memory space. This can be improved in the future.
631+
// CHECK-GATHER-LABEL: @negative_strided_paged_gather_read
632+
// CHECK-GATHER: linalg.generic
633+
634+
// -----
635+
636+
!storage = tensor<8192x8xf16>
637+
!ind0 = tensor<128xi64>
638+
!ind1 = tensor<8xi64>
639+
!x = tensor<128x8xf16>
640+
641+
#gather = {
642+
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
643+
affine_map<(d0, d1) -> (d1)>,
644+
affine_map<(d0, d1) -> (d0, d1)>],
645+
iterator_types = ["parallel", "parallel"]
646+
}
647+
648+
func.func @full_gather_read(%storage : !storage, %ind0: !ind0, %ind1 : !ind1) -> !x {
649+
%x = tensor.empty() : !x
650+
%x_g = linalg.generic #gather
651+
ins(%ind0, %ind1 : !ind0, !ind1)
652+
outs(%x : !x) {
653+
^bb0(%id0: i64, %id1 : i64, %out: f16):
654+
%idx0 = arith.index_cast %id0 : i64 to index
655+
%idx1 = arith.index_cast %id1 : i64 to index
656+
%extracted = tensor.extract %storage[%idx0, %idx1] : !storage
657+
linalg.yield %extracted : f16
658+
} -> !x
659+
return %x_g : !x
660+
}
661+
662+
// CHECK-GATHER-LABEL: @full_gather_read
663+
// CHECK-GATHER-SAME: %[[ARG0:.+]]: tensor<8192x8xf16>, %[[ARG1:.+]]: tensor<128xi64>, %[[ARG2:.+]]: tensor<8xi64>
664+
// CHECK-GATHER-DAG: %[[IDX0:.+]] = vector.transfer_read %[[ARG1]]
665+
// CHECK-GATHER-DAG: %[[IDX1:.+]] = vector.transfer_read %[[ARG2]]
666+
// CHECK-GATHER-DAG: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : vector<128xi64> to vector<128xindex>
667+
// CHECK-GATHER-DAG: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : vector<8xi64> to vector<8xindex>
668+
// CHECK-GATHER-DAG: %[[GATHER:.+]] = iree_vector_ext.transfer_gather %[[ARG0]]
669+
// CHECK-GATHER-SAME: [%[[CAST0]]: vector<128xindex>, %[[CAST1]]: vector<8xindex>]
670+
// CHECK-GATHER: vector.transfer_write %[[GATHER]], %{{.*}}
671+
672+
// -----
673+
674+
!storage = tensor<8192x8xf16>
675+
!ind0 = tensor<128xi64>
676+
!ind1 = tensor<8xi64>
677+
!x = tensor<128x8xf16>
678+
679+
#gather = {
680+
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
681+
affine_map<(d0, d1) -> (d1)>,
682+
affine_map<(d0, d1) -> (d0, d1)>,
683+
affine_map<(d0, d1) -> (d0, d1)>],
684+
iterator_types = ["parallel", "parallel"]
685+
}
686+
687+
func.func @multi_extract(%storage : !storage, %storage2: !storage, %ind0: !ind0, %ind1 : !ind1) -> ( !x, !x ) {
688+
%x = tensor.empty() : !x
689+
%x_g, %x_g1 = linalg.generic #gather
690+
ins(%ind0, %ind1 : !ind0, !ind1)
691+
outs(%x, %x : !x, !x) {
692+
^bb0(%id0: i64, %id1 : i64, %out: f16, %out2: f16):
693+
%idx0 = arith.index_cast %id0 : i64 to index
694+
%idx1 = arith.index_cast %id1 : i64 to index
695+
%extracted = tensor.extract %storage[%idx0, %idx1] : !storage
696+
%idx2 = arith.index_cast %id0 : i64 to index
697+
%idx3 = arith.index_cast %id1 : i64 to index
698+
%extracted1 = tensor.extract %storage2[%idx2, %idx3] : !storage
699+
linalg.yield %extracted, %extracted1 : f16, f16
700+
} -> (!x, !x)
701+
return %x_g, %x_g1 : !x, !x
702+
}
703+
704+
// CHECK-GATHER-LABEL: @multi_extract
705+
// CHECK-GATHER-COUNT-2: transfer_gather

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@ iree_compiler_cc_library(
3737
hdrs = [
3838
"Passes.h",
3939
"Passes.h.inc",
40+
"Transforms.h",
4041
],
4142
deps = [
4243
":PassesIncGen",
4344
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
45+
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
4446
"@llvm-project//llvm:Support",
4547
"@llvm-project//mlir:ArithDialect",
4648
"@llvm-project//mlir:FunctionInterfaces",
4749
"@llvm-project//mlir:IR",
50+
"@llvm-project//mlir:LinalgDialect",
51+
"@llvm-project//mlir:LinalgTransforms",
4852
"@llvm-project//mlir:Pass",
4953
"@llvm-project//mlir:Support",
5054
"@llvm-project//mlir:TensorDialect",

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ iree_cc_library(
2525
HDRS
2626
"Passes.h"
2727
"Passes.h.inc"
28+
"Transforms.h"
2829
SRCS
2930
"Passes.cpp"
3031
"VectorExtFoldUnitExtentDims.cpp"
@@ -35,6 +36,8 @@ iree_cc_library(
3536
MLIRArithDialect
3637
MLIRFunctionInterfaces
3738
MLIRIR
39+
MLIRLinalgDialect
40+
MLIRLinalgTransforms
3841
MLIRPass
3942
MLIRSupport
4043
MLIRTensorDialect
@@ -44,6 +47,7 @@ iree_cc_library(
4447
MLIRVectorTransforms
4548
MLIRVectorUtils
4649
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
50+
iree::compiler::Dialect::LinalgExt::Utils
4751
PUBLIC
4852
)
4953

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_
8+
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_
9+
10+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
11+
#include "mlir/IR/Builders.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
namespace mlir::iree_compiler::IREE::VectorExt {
15+
16+
LogicalResult vectorizeGatherLikeGenericToTransferGather(
17+
RewriterBase &rewriter, linalg::GenericOp linalgOp,
18+
ArrayRef<int64_t> vectorSizes = {}, ArrayRef<bool> scalableVecDims = {},
19+
bool vectorizeNDExtract = false);
20+
21+
void populateVectorTransferGatherLoweringPatterns(RewritePatternSet &patterns);
22+
23+
}; // namespace mlir::iree_compiler::IREE::VectorExt
24+
25+
#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_

0 commit comments

Comments
 (0)