
Commit 832c86a

WillFroom authored and Google-ML-Automation committed

[XLA:CPU][XTile] Add lowering for StableHLO DotGeneral.

PiperOrigin-RevId: 820214413

1 parent: 14a5144
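
For context: stablehlo.dot_general is a generalized matrix multiply whose batch and contracting dimensions are given explicitly as dimension numbers, and whose result dimensions are laid out as (batch..., lhs free..., rhs free...). A minimal NumPy sketch of these semantics (illustrative only, not XLA code):

import numpy as np

# One batch dim and one contracting dim per side:
# lhs_batch = [0], rhs_batch = [0], lhs_contracting = [2], rhs_contracting = [1].
lhs = np.random.rand(4, 8, 16).astype(np.float32)
rhs = np.random.rand(4, 16, 8).astype(np.float32)

result = np.einsum("bik,bkj->bij", lhs, rhs)
assert result.shape == (4, 8, 8)  # (batch, lhs free, rhs free)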

3 files changed (+245, -4)

xla/backends/cpu/codegen/tiled/tiled_kernel_test.py
Lines changed: 71 additions & 3 deletions
@@ -33,6 +33,7 @@ def compare_kernel(
     output_shape: tuple[int, ...],
     dtype,
     expected_output: Callable[[np.ndarray, ...], np.ndarray],
+    exact: bool = True,
 ) -> None:
   mlir_emitter = cpu_testlib.MlirTestKernelEmitter(
       ir, kernel_name, (num_workgroups, 1, 1)
@@ -49,9 +50,14 @@ def compare_kernel(
   output_tensor = create_literal(np.zeros(output_shape, dtype=dtype))
   runner.call(input_tensors + [output_tensor])

-  np.testing.assert_array_equal(
-      np.asarray(output_tensor), expected_output(*inputs)
-  )
+  if exact:
+    np.testing.assert_array_equal(
+        np.asarray(output_tensor), expected_output(*inputs)
+    )
+  else:
+    np.testing.assert_array_almost_equal(
+        np.asarray(output_tensor), expected_output(*inputs)
+    )


 class XtileLoweringTest(absltest.TestCase):
@@ -139,6 +145,68 @@ def test_add_tranpose(self):
         lambda arg: arg + arg.transpose(),
     )

+  def test_dot_single_tile(self):
+    ir = """
+      module @dot_single_tile {
+        xtile.entry_func @dot_single_tile(
+            %lhs: memref<8x16xf32>,
+            %rhs: memref<16x8xf32>,
+            %output: memref<8x8xf32>,
+            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
+          %offset = arith.constant 0 : index
+          %lhs_tile = xtile.extract %lhs[%offset, %offset][8, 16][1, 1] : memref<8x16xf32> -> tensor<8x16xf32>
+          %rhs_tile = xtile.extract %rhs[%offset, %offset][16, 8][1, 1] : memref<16x8xf32> -> tensor<16x8xf32>
+          %result = stablehlo.dot_general %lhs_tile, %rhs_tile, contracting_dims = [1] x [0] : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
+          xtile.insert %result into %output[%offset, %offset][8, 8][1, 1] : tensor<8x8xf32> -> memref<8x8xf32>
+          xtile.return
+        }
+      }
+    """
+
+    compare_kernel(
+        ir,
+        "dot_single_tile",
+        1,
+        [(8, 16), (16, 8)],
+        (8, 8),
+        np.float32,
+        lambda lhs, rhs: lhs @ rhs,
+        False,
+    )
+
+  def test_dot_fusion_single_tile(self):
+    ir = """
+      module @dot_fusion_single_tile {
+        xtile.entry_func @dot_fusion_single_tile(
+            %lhs_0: memref<8x16xf32>,
+            %lhs_1: memref<8x16xf32>,
+            %rhs: memref<16x1xf32>,
+            %output: memref<8x1xf32>,
+            %tile_id: index) attributes {xtile.tiling_info = #xtile.tiling_info<tile_count:1, tiles_per_workgroup:1>} {
+          %offset = arith.constant 0 : index
+          %lhs_0_tile = xtile.extract %lhs_0[%offset, %offset][8, 16][1, 1] : memref<8x16xf32> -> tensor<8x16xf32>
+          %lhs_1_tile = xtile.extract %lhs_1[%offset, %offset][8, 16][1, 1] : memref<8x16xf32> -> tensor<8x16xf32>
+          %add_lhs = arith.addf %lhs_0_tile, %lhs_1_tile : tensor<8x16xf32>
+          %rhs_tile = xtile.extract %rhs[%offset, %offset][16, 1][1, 1] : memref<16x1xf32> -> tensor<16xf32>
+          %result = stablehlo.dot_general %add_lhs, %rhs_tile, contracting_dims = [1] x [0] : (tensor<8x16xf32>, tensor<16xf32>) -> tensor<8xf32>
+          %tanh_result = math.tanh %result : tensor<8xf32>
+          xtile.insert %tanh_result into %output[%offset, %offset][8, 1][1, 1] : tensor<8xf32> -> memref<8x1xf32>
+          xtile.return
+        }
+      }
+    """
+
+    compare_kernel(
+        ir,
+        "dot_fusion_single_tile",
+        1,
+        [(8, 16), (8, 16), (16, 1)],
+        (8, 1),
+        np.float32,
+        lambda lhs_0, lhs_1, rhs: np.tanh((lhs_0 + lhs_1) @ rhs),
+        False,
+    )
+

 if __name__ == "__main__":
   absltest.main()
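
A note on the new exact flag: the dot tests above pass exact=False because a vectorized float32 contraction may accumulate partial products in a different order than the NumPy reference, so bit-exact comparison is too strict. A minimal sketch of the effect (illustrative only; decimal=5 is chosen here for headroom, while assert_array_almost_equal defaults to 6 decimal places):

import numpy as np

lhs = np.random.rand(8, 16).astype(np.float32)
rhs = np.random.rand(16, 8).astype(np.float32)

# The same matmul computed with two different accumulation orders:
reference = lhs @ rhs
reordered = np.sum(lhs[:, :, None] * rhs[None, :, :], axis=1, dtype=np.float32)

# Close to machine precision, but not necessarily bit-identical:
np.testing.assert_array_almost_equal(reordered, reference, decimal=5)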

xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc
Lines changed: 158 additions & 1 deletion
@@ -14,13 +14,18 @@ limitations under the License.
 ==============================================================================*/

 #include <cassert>
+#include <cstdint>
 #include <memory>
 #include <utility>

+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -42,6 +47,158 @@ namespace xla::cpu {

 namespace {

+mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type) {
+  return mlir::VectorType::get(tensor_type.getShape(),
+                               tensor_type.getElementType());
+}
+
+mlir::TypedValue<mlir::VectorType> CastToVector(
+    mlir::PatternRewriter& rewriter,
+    mlir::TypedValue<mlir::RankedTensorType> tensor_value) {
+  auto vector_type = GetVectorType(tensor_value.getType());
+  auto cast_op = rewriter.create<mlir::UnrealizedConversionCastOp>(
+      tensor_value.getLoc(), vector_type, tensor_value);
+  return mlir::cast<mlir::TypedValue<mlir::VectorType>>(cast_op.getResult(0));
+}
+
+mlir::AffineMapAttr GetOperandIndexingMap(
+    mlir::OpBuilder& builder, int64_t iterator_count, int64_t rank,
+    llvm::ArrayRef<int64_t> batch_dims,
+    llvm::ArrayRef<int64_t> contracting_dims, int64_t free_dim_offset) {
+  llvm::SmallVector<unsigned> targets(rank, -1);
+  unsigned idx = 0;
+  for (int64_t dim : batch_dims) {
+    targets[dim] = idx++;
+  }
+  for (int64_t dim : contracting_dims) {
+    targets[dim] = idx++;
+  }
+  for (unsigned& target : targets) {
+    if (target == -1) {
+      target = free_dim_offset + idx++;
+    }
+  }
+  auto affine_map = mlir::AffineMap::getMultiDimMapWithTargets(
+      iterator_count, targets, builder.getContext());
+
+  return mlir::AffineMapAttr::get(affine_map);
+}
+
+mlir::AffineMapAttr GetOutputIndexingMap(mlir::OpBuilder& builder,
+                                         int64_t iterator_count,
+                                         int64_t batch_dim_count,
+                                         int64_t contracting_dim_count) {
+  llvm::SmallVector<unsigned> targets(iterator_count - contracting_dim_count);
+  unsigned idx = 0;
+  for (int64_t dim = 0; dim != batch_dim_count; ++dim) {
+    targets[dim] = idx++;
+  }
+  idx += contracting_dim_count;
+  int64_t total_free_dims =
+      iterator_count - batch_dim_count - contracting_dim_count;
+  for (int64_t dim = 0; dim != total_free_dims; ++dim) {
+    targets[batch_dim_count + dim] = idx++;
+  }
+  auto affine_map = mlir::AffineMap::getMultiDimMapWithTargets(
+      iterator_count, targets, builder.getContext());
+
+  return mlir::AffineMapAttr::get(affine_map);
+}
+
+mlir::ArrayAttr GetIteratorTypes(mlir::OpBuilder& builder,
+                                 int64_t iterator_count,
+                                 int64_t batch_dim_count,
+                                 int64_t contracting_dim_count) {
+  llvm::SmallVector<mlir::Attribute> iterator_types;
+  iterator_types.reserve(iterator_count);
+  for (int64_t dim = 0; dim != batch_dim_count; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::parallel));
+  }
+  for (int64_t dim = 0; dim != contracting_dim_count; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::reduction));
+  }
+  int64_t free_dims = iterator_count - batch_dim_count - contracting_dim_count;
+  for (int64_t dim = 0; dim != free_dims; ++dim) {
+    iterator_types.push_back(builder.getAttr<mlir::vector::IteratorTypeAttr>(
+        mlir::vector::IteratorType::parallel));
+  }
+
+  return mlir::ArrayAttr::get(builder.getContext(), iterator_types);
+}
+
+// Lowers stablehlo.dot_general to vector.contract.
+// vector.contract is very general, as described here:
+// https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop
+// In this lowering the iterator order passed via the attribute is of the form:
+// (batch..., contracting..., free_lhs..., free_rhs...)
+// TODO(willfroom): Check if there is any performance impact from the order.
+struct LowerDotGeneral : mlir::OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::stablehlo::DotGeneralOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    auto lhs_vector = CastToVector(rewriter, op.getLhs());
+    auto lhs_rank = lhs_vector.getType().getRank();
+
+    auto rhs_vector = CastToVector(rewriter, op.getRhs());
+    auto rhs_rank = rhs_vector.getType().getRank();
+
+    auto result_vector_type = GetVectorType(op.getResult().getType());
+    auto zero_const = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(), result_vector_type.getElementType(),
+        rewriter.getZeroAttr(result_vector_type.getElementType()));
+    // TODO(willfroom): Ensure this is being folded into the accumulator in the
+    // dot loop.
+    mlir::Value accumulator = rewriter.create<mlir::vector::BroadcastOp>(
+        op->getLoc(), result_vector_type, zero_const);
+
+    mlir::stablehlo::DotDimensionNumbersAttr dimension_numbers =
+        op.getDotDimensionNumbers();
+
+    llvm::ArrayRef<int64_t> lhs_batch =
+        dimension_numbers.getLhsBatchingDimensions();
+    llvm::ArrayRef<int64_t> lhs_contracting =
+        dimension_numbers.getLhsContractingDimensions();
+
+    llvm::ArrayRef<int64_t> rhs_batch =
+        dimension_numbers.getRhsBatchingDimensions();
+    llvm::ArrayRef<int64_t> rhs_contracting =
+        dimension_numbers.getRhsContractingDimensions();
+
+    int64_t lhs_free_dims =
+        lhs_rank - lhs_batch.size() - lhs_contracting.size();
+    int64_t rhs_free_dims =
+        rhs_rank - rhs_batch.size() - rhs_contracting.size();
+    int64_t iterator_count = lhs_batch.size() + lhs_contracting.size() +
+                             lhs_free_dims + rhs_free_dims;
+
+    mlir::Attribute lhs_indexing_map = GetOperandIndexingMap(
+        rewriter, iterator_count, lhs_rank, lhs_batch, lhs_contracting, 0);
+    mlir::Attribute rhs_indexing_map =
+        GetOperandIndexingMap(rewriter, iterator_count, rhs_rank, rhs_batch,
+                              rhs_contracting, lhs_free_dims);
+    mlir::Attribute output_indexing_map = GetOutputIndexingMap(
+        rewriter, iterator_count, lhs_batch.size(), lhs_contracting.size());
+
+    mlir::ArrayAttr indexing_maps = rewriter.getArrayAttr(
+        {lhs_indexing_map, rhs_indexing_map, output_indexing_map});
+    mlir::ArrayAttr iterator_types = GetIteratorTypes(
+        rewriter, iterator_count, lhs_batch.size(), lhs_contracting.size());
+
+    mlir::Value result_vector = rewriter.create<mlir::vector::ContractionOp>(
+        op->getLoc(), lhs_vector, rhs_vector, accumulator, indexing_maps,
+        iterator_types);
+
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getResult().getType(), result_vector);
+
+    return mlir::success();
+  }
+};
+
 struct LowerTranspose : mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;

@@ -79,7 +236,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
   void runOnOperation() override {
     mlir::MLIRContext* context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.add<LowerTranspose>(context);
+    patterns.add<LowerTranspose, LowerDotGeneral>(context);
     if (mlir::failed(
             mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       signalPassFailure();
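
To make the indexing-map construction above concrete, here is a plain-Python mirror of the target assignment in GetOperandIndexingMap (a sketch under the iterator order (batch..., contracting..., free_lhs..., free_rhs...); the function name is ours, not from the commit):

def operand_indexing_targets(rank, batch_dims, contracting_dims, free_dim_offset):
  # Each operand dimension is mapped to one iterator: batch iterators first,
  # then contracting iterators, then free iterators shifted by free_dim_offset
  # (0 for the LHS, lhs_free_dims for the RHS).
  targets = [None] * rank
  idx = 0
  for dim in batch_dims:
    targets[dim] = idx
    idx += 1
  for dim in contracting_dims:
    targets[dim] = idx
    idx += 1
  for dim in range(rank):
    if targets[dim] is None:
      targets[dim] = free_dim_offset + idx
      idx += 1
  return targets

# A plain matmul (no batch dims) reproduces the maps in the .mlir test below:
assert operand_indexing_targets(2, [], [1], 0) == [1, 0]  # lhs: (d1, d0)
assert operand_indexing_targets(2, [], [0], 1) == [0, 2]  # rhs: (d0, d2)
# With a leading batch dim on a rank-3 LHS contracting its last dim:
assert operand_indexing_targets(3, [0], [2], 0) == [0, 2, 1]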

xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir
Lines changed: 16 additions & 0 deletions
@@ -6,3 +6,19 @@ func.func @transpose(%input : tensor<1024x32xf32>) -> tensor<32x1024xf32> {
   return %transposed : tensor<32x1024xf32>
 }
 // -----
+
+// CHECK-DAG: #[[LHS_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
+// CHECK-DAG: #[[RHS_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[OUTPUT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+func.func @dot_general(%lhs : tensor<1024x32xf32>, %rhs : tensor<32x1024xf32>) -> tensor<1024x1024xf32> {
+  // CHECK: %[[ACCUMULATOR:.*]] = arith.constant dense<0.000000e+00> : vector<1024x1024xf32>
+  // CHECK: vector.contract
+  // CHECK-SAME: {indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUTPUT_MAP]]],
+  // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"],
+  // CHECK-SAME: kind = #vector.kind<add>}
+  // CHECK-SAME: %[[ACCUMULATOR]] : vector<1024x32xf32>, vector<32x1024xf32> into vector<1024x1024xf32>
+  %result = stablehlo.dot_general %lhs, %rhs, contracting_dims = [1] x [0] : (tensor<1024x32xf32>, tensor<32x1024xf32>) -> tensor<1024x1024xf32>
+  return %result : tensor<1024x1024xf32>
+}
+
+// -----
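
Reading the expected maps back: with d0 the reduction iterator and d1/d2 parallel, the contraction gathers lhs(d1, d0) and rhs(d0, d2) into out(d1, d2), i.e. the ordinary matmul out[i, j] = sum over k of lhs[i, k] * rhs[k, j]. A quick NumPy cross-check of that reading (illustrative only, not part of the commit):

import numpy as np

lhs = np.random.rand(1024, 32).astype(np.float32)
rhs = np.random.rand(32, 1024).astype(np.float32)

np.testing.assert_array_almost_equal(
    np.einsum("ik,kj->ij", lhs, rhs), lhs @ rhs, decimal=3
)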
