@@ -14,13 +14,18 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cassert>
+#include <cstdint>
 #include <memory>
 #include <utility>
 
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -42,6 +47,158 @@ namespace xla::cpu {
 
 namespace {
 
+mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type) {
+  return mlir::VectorType::get(tensor_type.getShape(),
+                               tensor_type.getElementType());
+}
+
+mlir::TypedValue<mlir::VectorType> CastToVector(
+    mlir::PatternRewriter& rewriter,
+    mlir::TypedValue<mlir::RankedTensorType> tensor_value) {
+  auto vector_type = GetVectorType(tensor_value.getType());
+  auto cast_op = rewriter.create<mlir::UnrealizedConversionCastOp>(
+      tensor_value.getLoc(), vector_type, tensor_value);
+  return mlir::cast<mlir::TypedValue<mlir::VectorType>>(cast_op.getResult(0));
+}
+
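+// Builds the indexing map for one contraction operand. Illustrative worked
+// example (not from the original change): for an lhs of rank 3 with
+// batch_dims = [0], contracting_dims = [2], free_dim_offset = 0, and
+// iterator_count = 4 (one rhs free dim), `targets` becomes [0, 2, 1], i.e.
+// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>, where d0 is the batch
+// iterator, d1 the contracting one, d2 the lhs free dim, and d3 the rhs
+// free dim.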
+mlir::AffineMapAttr GetOperandIndexingMap(
+    mlir::OpBuilder& builder, int64_t iterator_count, int64_t rank,
+    llvm::ArrayRef<int64_t> batch_dims,
+    llvm::ArrayRef<int64_t> contracting_dims, int64_t free_dim_offset) {
+  llvm::SmallVector<unsigned> targets(rank, -1);
+  unsigned idx = 0;
+  for (int64_t dim : batch_dims) {
+    targets[dim] = idx++;
+  }
+  for (int64_t dim : contracting_dims) {
+    targets[dim] = idx++;
+  }
+  for (unsigned& target : targets) {
+    if (target == -1) {
+      target = free_dim_offset + idx++;
+    }
+  }
+  auto affine_map = mlir::AffineMap::getMultiDimMapWithTargets(
+      iterator_count, targets, builder.getContext());
+
+  return mlir::AffineMapAttr::get(affine_map);
+}
+
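+// Builds the indexing map for the result: batch iterators first, then all
+// free iterators, skipping the contracting ones. Continuing the worked
+// example above (iterator_count = 4, one batch and one contracting dim),
+// `targets` becomes [0, 2, 3], i.e.
+// affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>.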
+mlir::AffineMapAttr GetOutputIndexingMap(mlir::OpBuilder& builder,
+                                         int64_t iterator_count,
+                                         int64_t batch_dim_count,
+                                         int64_t contracting_dim_count) {
+  llvm::SmallVector<unsigned> targets(iterator_count - contracting_dim_count);
+  unsigned idx = 0;
+  for (int64_t dim = 0; dim != batch_dim_count; ++dim) {
+    targets[dim] = idx++;
+  }
+  idx += contracting_dim_count;
+  int64_t total_free_dims =
+      iterator_count - batch_dim_count - contracting_dim_count;
+  for (int64_t dim = 0; dim != total_free_dims; ++dim) {
+    targets[batch_dim_count + dim] = idx++;
+  }
+  auto affine_map = mlir::AffineMap::getMultiDimMapWithTargets(
+      iterator_count, targets, builder.getContext());
+
+  return mlir::AffineMapAttr::get(affine_map);
+}
+
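+// Builds the iterator kinds matching the (batch..., contracting..., free...)
+// order used above: parallel for batch and free iterators, reduction for
+// contracting ones; e.g. [parallel, reduction, parallel, parallel] in the
+// running example.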
+mlir::ArrayAttr GetIteratorTypes(mlir::OpBuilder& builder,
+                                 int64_t iterator_count,
+                                 int64_t batch_dim_count,
+                                 int64_t contracting_dim_count) {
+  llvm::SmallVector<mlir::Attribute> iterator_types;
+  iterator_types.reserve(iterator_count);
+  for (int64_t dim = 0; dim != batch_dim_count; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::parallel));
+  }
+  for (int64_t dim = 0; dim != contracting_dim_count; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::reduction));
+  }
+  int64_t free_dims = iterator_count - batch_dim_count - contracting_dim_count;
+  for (int64_t dim = 0; dim != free_dims; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::parallel));
+  }
+
+  return mlir::ArrayAttr::get(builder.getContext(), iterator_types);
+}
+
+// Lowers stablehlo.dot_general to vector.contract.
+// vector.contract is very general, as described at:
+// https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop
+// In this lowering the contraction iterators are ordered as:
+// (batch..., contracting..., free_lhs..., free_rhs...)
+// TODO(willfroom): Check if there is any performance impact from this order.
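+//
+// Schematically (an illustrative sketch, not verbatim IR; the
+// unrealized_conversion_casts around the operands and result are elided):
+//
+//   %r = stablehlo.dot_general %lhs, %rhs
+//     // batching_dims = [0] x [0], contracting_dims = [2] x [1]
+//     : (tensor<8x16x32xf32>, tensor<8x32x4xf32>) -> tensor<8x16x4xf32>
+//
+// becomes roughly:
+//
+//   %acc = vector.broadcast %c0_f32 : f32 to vector<8x16x4xf32>
+//   %r = vector.contract {
+//       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>,
+//                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+//                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>],
+//       iterator_types = ["parallel", "reduction", "parallel", "parallel"],
+//       kind = #vector.kind<add>}
+//     %lhs_v, %rhs_v, %acc
+//     : vector<8x16x32xf32>, vector<8x32x4xf32> into vector<8x16x4xf32>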
+struct LowerDotGeneral : mlir::OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::stablehlo::DotGeneralOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    auto lhs_vector = CastToVector(rewriter, op.getLhs());
+    auto lhs_rank = lhs_vector.getType().getRank();
+
+    auto rhs_vector = CastToVector(rewriter, op.getRhs());
+    auto rhs_rank = rhs_vector.getType().getRank();
+
+    auto result_vector_type = GetVectorType(op.getResult().getType());
+    auto zero_const = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(), result_vector_type.getElementType(),
+        rewriter.getZeroAttr(result_vector_type.getElementType()));
+    // TODO(willfroom): Ensure this is being folded into the accumulator in
+    // the dot loop.
+    mlir::Value accumulator = rewriter.create<mlir::vector::BroadcastOp>(
+        op->getLoc(), result_vector_type, zero_const);
+
+    mlir::stablehlo::DotDimensionNumbersAttr dimension_numbers =
+        op.getDotDimensionNumbers();
+
+    llvm::ArrayRef<int64_t> lhs_batch =
+        dimension_numbers.getLhsBatchingDimensions();
+    llvm::ArrayRef<int64_t> lhs_contracting =
+        dimension_numbers.getLhsContractingDimensions();
+
+    llvm::ArrayRef<int64_t> rhs_batch =
+        dimension_numbers.getRhsBatchingDimensions();
+    llvm::ArrayRef<int64_t> rhs_contracting =
+        dimension_numbers.getRhsContractingDimensions();
+
+    int64_t lhs_free_dims =
+        lhs_rank - lhs_batch.size() - lhs_contracting.size();
+    int64_t rhs_free_dims =
+        rhs_rank - rhs_batch.size() - rhs_contracting.size();
+    int64_t iterator_count = lhs_batch.size() + lhs_contracting.size() +
+                             lhs_free_dims + rhs_free_dims;
+
+    mlir::Attribute lhs_indexing_map = GetOperandIndexingMap(
+        rewriter, iterator_count, lhs_rank, lhs_batch, lhs_contracting, 0);
+    mlir::Attribute rhs_indexing_map =
+        GetOperandIndexingMap(rewriter, iterator_count, rhs_rank, rhs_batch,
+                              rhs_contracting, lhs_free_dims);
+    mlir::Attribute output_indexing_map = GetOutputIndexingMap(
+        rewriter, iterator_count, lhs_batch.size(), lhs_contracting.size());
+
+    mlir::ArrayAttr indexing_maps = rewriter.getArrayAttr(
+        {lhs_indexing_map, rhs_indexing_map, output_indexing_map});
+    mlir::ArrayAttr iterator_types = GetIteratorTypes(
+        rewriter, iterator_count, lhs_batch.size(), lhs_contracting.size());
+
+    mlir::Value result_vector = rewriter.create<mlir::vector::ContractionOp>(
+        op->getLoc(), lhs_vector, rhs_vector, accumulator, indexing_maps,
+        iterator_types);
+
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getResult().getType(), result_vector);
+
+    return mlir::success();
+  }
+};
+
 struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -79,7 +236,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
   void runOnOperation() override {
     mlir::MLIRContext* context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.add<LowerTranspose>(context);
+    patterns.add<LowerTranspose, LowerDotGeneral>(context);
     if (mlir::failed(
             mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       signalPassFailure();