@@ -25,219 +25,6 @@ namespace mlir::iree_compiler::IREE::VectorExt {
2525
2626using VectorValue = TypedValue<VectorType>;
2727
28- bool PerDimLayoutAttr::contains (const LayoutDimension &dim) {
29- for (LayoutDimensionAttr label : getLabels ()) {
30- if (label.getValue () == dim)
31- return true ;
32- }
33- return false ;
34- }
35-
36- std::optional<int64_t > PerDimLayoutAttr::getShape (const LayoutDimension &dim) {
37- for (auto value : llvm::zip (getLabels (), getShapes ())) {
38- if (dim == std::get<0 >(value).getValue ())
39- return std::get<1 >(value);
40- }
41- return std::nullopt ;
42- }
43-
44- std::optional<int64_t > LayoutAttr::getShape (const LayoutDimension &dim) const {
45- for (PerDimLayoutAttr layout : getLayouts ()) {
46- std::optional<int64_t > maybeShape = layout.getShape (dim);
47- if (maybeShape)
48- return maybeShape.value ();
49- }
50- return std::nullopt ;
51- }
52-
53- // Get the SIMT Vector shape in the order specified by dims. If no dims are
54- // specified, then return an empty vector.
55- LogicalResult LayoutAttr::isValidLayout (ShapedType shapeTy,
56- Location loc) const {
57- ArrayRef<int64_t > shape = shapeTy.getShape ();
58- if (shape.size () != getRank ()) {
59- return emitError (loc, " Rank of vector (" )
60- << shape.size () << " ) does not match rank of layout (" << getRank ()
61- << " )." ;
62- }
63- for (auto [idx, layout] : llvm::enumerate (getLayouts ())) {
64- ArrayRef<int64_t > layoutShape = layout.getShapes ();
65- int64_t expectedShape =
66- std::reduce (layoutShape.begin (), layoutShape.end (),
67- static_cast <int64_t >(1 ), std::multiplies<int64_t >());
68- if (expectedShape != shape[idx]) {
69- std::string shapeStr;
70- llvm::raw_string_ostream shapeOs (shapeStr);
71- llvm::interleaveComma (shape, shapeOs);
72- std::string layoutStr;
73- llvm::raw_string_ostream layoutOs (layoutStr);
74- printStripped (layoutOs);
75- return emitError (loc, " Vector shape: [" )
76- << shapeStr << " ] does not match the layout (" << layoutStr
77- << " ) at dim " << idx
78- << " . Dimension expected by layout: " << expectedShape
79- << " actual: " << shape[idx];
80- }
81- }
82- return success ();
83- }
84-
85- // Project out the layout for the specified dimensions
86- // resulting in the layout for a lower dimensional vector.
87- VectorLayoutInterface LayoutAttr::project (ArrayRef<bool > droppedDims) const {
88- assert (droppedDims.size () == getRank () &&
89- " droppedDims size must match layout size" );
90-
91- ArrayRef<PerDimLayoutAttr> layouts = getLayouts ();
92- SmallVector<PerDimLayoutAttr> newLayouts;
93- for (auto pair : llvm::zip (droppedDims, layouts)) {
94- if (!std::get<0 >(pair))
95- newLayouts.push_back (std::get<1 >(pair));
96- }
97- return LayoutAttr::get (getContext (), newLayouts);
98- }
99-
100- // Permute the layout according to the provided permutation
101- // vector. The dimensionality of the layout remains the same.
102- VectorLayoutInterface LayoutAttr::permute (ArrayRef<int64_t > permutation) const {
103- assert (permutation.size () == getRank () &&
104- " permutation size must match layout rank" );
105-
106- ArrayRef<PerDimLayoutAttr> layouts = getLayouts ();
107- SmallVector<PerDimLayoutAttr> newLayouts;
108- for (unsigned index : permutation) {
109- assert (index >= 0 && index < getRank ());
110- newLayouts.push_back (layouts[index]);
111- }
112- return LayoutAttr::get (getContext (), newLayouts);
113- }
114-
115- // This function returns the distributed shape of the SIMT
116- // vector and evaluates it in the following order:
117- // BATCHX, BATCHY, VECTORY, VECTORX
118- // The vector dimensions are combined into a single SIMT
119- // vector dimension.
120- SmallVector<int64_t > LayoutAttr::getDistributedShape () const {
121- SmallVector<LayoutDimension> labels{
122- LayoutDimension::BATCHX, LayoutDimension::BATCHY,
123- LayoutDimension::VECTORY, LayoutDimension::VECTORX};
124- SmallVector<int64_t > simtVectorShape;
125- std::optional<int64_t > vectorShape;
126- for (LayoutDimension dim : labels) {
127- ArrayRef<PerDimLayoutAttr> layouts = getLayouts ();
128- for (PerDimLayoutAttr layout : layouts) {
129- if (!layout.contains (dim))
130- continue ;
131- int64_t shape = layout.getShape (dim).value ();
132- if (isVectorDimension (dim)) {
133- vectorShape = shape * vectorShape.value_or (1 );
134- continue ;
135- }
136- simtVectorShape.push_back (shape);
137- }
138- }
139- if (vectorShape)
140- simtVectorShape.push_back (vectorShape.value ());
141- return simtVectorShape;
142- }
143-
144- PerDimLayoutAttr LayoutAttr::getDimLayout (int64_t dim) const {
145- assert (dim >= 0 && dim < getRank ());
146- return getLayouts ()[dim];
147- }
148-
149- std::optional<int64_t > LayoutAttr::getBatchDim (int64_t dim) {
150- assert (dim < getRank ());
151- PerDimLayoutAttr layout = getDimLayout (dim);
152- for (auto [name, shape] :
153- llvm::zip_equal (layout.getLabels (), layout.getShapes ())) {
154- if (isBatchDimension (name.getValue ()))
155- return shape;
156- }
157- return std::nullopt ;
158- }
159-
160- std::optional<int64_t > LayoutAttr::getLaneDim (int64_t dim) {
161- assert (dim < getRank ());
162- PerDimLayoutAttr layout = getDimLayout (dim);
163- for (auto [name, shape] :
164- llvm::zip_equal (layout.getLabels (), layout.getShapes ())) {
165- if (isLaneDimension (name.getValue ()))
166- return shape;
167- }
168- return std::nullopt ;
169- }
170-
171- std::optional<LayoutDimension> LayoutAttr::getLane (int64_t dim) {
172- assert (dim < getRank ());
173- PerDimLayoutAttr layout = getDimLayout (dim);
174- for (auto [name, shape] :
175- llvm::zip_equal (layout.getLabels (), layout.getShapes ())) {
176- if (isLaneDimension (name.getValue ()))
177- return name.getValue ();
178- }
179- return std::nullopt ;
180- }
181-
182- int64_t LayoutAttr::getRank () const { return getLayouts ().size (); }
183-
184- std::tuple<int64_t , int64_t , int64_t > LayoutAttr::getLaneGrid () {
185- int64_t laneX = 1 ;
186- int64_t laneY = 1 ;
187- int64_t laneZ = 1 ;
188- for (PerDimLayoutAttr dimLayout : getLayouts ()) {
189- // Note that valid layouts only include at most one instance of each
190- // dimension type, so this is simply doing assignment on the first instance
191- // of each lane index, not an accumulative product.
192- auto maybeXShape = dimLayout.getShape (LayoutDimension::LANEX);
193- laneX *= maybeXShape.value_or (1 );
194- auto maybeYShape = dimLayout.getShape (LayoutDimension::LANEY);
195- laneY *= maybeYShape.value_or (1 );
196- auto maybeZShape = dimLayout.getShape (LayoutDimension::LANEZ);
197- laneZ *= maybeZShape.value_or (1 );
198- }
199- return std::make_tuple (laneX, laneY, laneZ);
200- }
201-
202- uint64_t LayoutAttr::getShuffleOffset (int64_t reductionDim) {
203- uint64_t offset = 0 ;
204- std::optional<LayoutDimension> laneDim = getLane (reductionDim);
205- if (!laneDim)
206- return offset;
207- switch (laneDim.value ()) {
208- case LayoutDimension::LANEX:
209- offset = 1 ;
210- break ;
211- case LayoutDimension::LANEY:
212- offset = getShape (LayoutDimension::LANEX).value_or (0 );
213- break ;
214- case LayoutDimension::LANEZ:
215- offset = getShape (LayoutDimension::LANEX).value_or (0 ) *
216- getShape (LayoutDimension::LANEY).value_or (0 );
217- break ;
218- default :
219- assert (false && " Invalid dimension! Expected lane dimension" );
220- break ;
221- }
222- return offset;
223- }
224-
225- bool LayoutAttr::hasLaneConflictWith (const LayoutAttr &other) {
226- SmallVector<LayoutDimension> laneDims{
227- LayoutDimension::LANEX, LayoutDimension::LANEY, LayoutDimension::LANEZ};
228- for (LayoutDimension dim : laneDims) {
229- std::optional<int64_t > shape = getShape (dim);
230- std::optional<int64_t > otherShape = other.getShape (dim);
231- if ((shape && !otherShape) || (!shape && otherShape))
232- return true ;
233- if (shape && otherShape) {
234- if (shape.value () != otherShape.value ())
235- return true ;
236- }
237- }
238- return false ;
239- }
240-
24128// Project the nested layout. This take a mask on the dimensions of the vector
24229// associated with this layout and projects out those dimensions. This reduces
24330// the rank of the layout in the process.
0 commit comments