Skip to content

Commit 189c85b

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XLA:CPU][XTile] Create simple lowering for tiled ops.
PiperOrigin-RevId: 820160792
1 parent c039b31 commit 189c85b

22 files changed

+1953
-47
lines changed

xla/backends/cpu/codegen/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,15 @@ cc_library(
146146
"//xla:util",
147147
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
148148
"//xla/backends/cpu/codegen/emitters/transforms:passes",
149+
"//xla/backends/cpu/codegen/tiled/transforms:passes",
149150
"//xla/codegen:llvm_ir_kernel_source",
150151
"//xla/codegen:mlir_kernel_source",
151152
"//xla/codegen:trace_pass_instrumentation",
152153
"//xla/codegen/emitters/ir:xla",
153154
"//xla/codegen/emitters/ir:xla_attrs_inc_gen",
154155
"//xla/codegen/emitters/transforms:pass_pipelines",
155156
"//xla/codegen/emitters/transforms:passes",
157+
"//xla/codegen/xtile/ir:xtile",
156158
"//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc",
157159
"//xla/mlir_hlo",
158160
"//xla/service/gpu/model/experimental:symbolic_expr",
@@ -182,6 +184,7 @@ cc_library(
182184
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
183185
"@llvm-project//mlir:MathDialect",
184186
"@llvm-project//mlir:MathToLLVM",
187+
"@llvm-project//mlir:MemRefToLLVM",
185188
"@llvm-project//mlir:MemRefTransforms",
186189
"@llvm-project//mlir:Pass",
187190
"@llvm-project//mlir:ReconcileUnrealizedCasts",
@@ -191,7 +194,10 @@ cc_library(
191194
"@llvm-project//mlir:TensorDialect",
192195
"@llvm-project//mlir:ToLLVMIRTranslation",
193196
"@llvm-project//mlir:Transforms",
197+
"@llvm-project//mlir:UBToLLVM",
194198
"@llvm-project//mlir:VectorDialect",
199+
"@llvm-project//mlir:VectorToLLVM",
200+
"@llvm-project//mlir:VectorToSCF",
195201
"@stablehlo//:stablehlo_passes",
196202
"@tsl//tsl/profiler/lib:traceme",
197203
"@tsl//tsl/profiler/lib:traceme_encode",

xla/backends/cpu/codegen/fusion_compiler.cc

Lines changed: 109 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@ limitations under the License.
3838
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
3939
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
4040
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
41+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4142
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
4243
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
44+
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
45+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
46+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
47+
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
4348
#include "mlir/Dialect/Affine/IR/AffineOps.h"
4449
#include "mlir/Dialect/Affine/Passes.h"
4550
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -58,6 +63,7 @@ limitations under the License.
5863
#include "mlir/IR/Attributes.h"
5964
#include "mlir/IR/BuiltinAttributes.h"
6065
#include "mlir/IR/BuiltinOps.h"
66+
#include "mlir/IR/Operation.h"
6167
#include "mlir/IR/Visitors.h"
6268
#include "mlir/Pass/PassManager.h"
6369
#include "mlir/Support/LLVM.h"
@@ -70,6 +76,7 @@ limitations under the License.
7076
#include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_dialect.h"
7177
#include "xla/backends/cpu/codegen/emitters/transforms/passes.h"
7278
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
79+
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
7380
#include "xla/codegen/emitters/ir/xla_attrs.h.inc"
7481
#include "xla/codegen/emitters/ir/xla_dialect.h"
7582
#include "xla/codegen/emitters/ir/xla_ops.h"
@@ -78,6 +85,8 @@ limitations under the License.
7885
#include "xla/codegen/llvm_ir_kernel_source.h"
7986
#include "xla/codegen/mlir_kernel_source.h"
8087
#include "xla/codegen/trace_pass_instrumentation.h"
88+
#include "xla/codegen/xtile/ir/xtile_dialect.h"
89+
#include "xla/codegen/xtile/ir/xtile_ops.h"
8190
#include "xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h"
8291
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
8392
#include "xla/status_macros.h"
@@ -114,6 +123,34 @@ static std::unique_ptr<::mlir::Pass> CreateConvertMathToLLVMPass() {
114123
return mlir::createConvertMathToLLVMPass(options);
115124
}
116125

126+
// The final lowering passes common to both scalar and tiled kernels.
127+
// These passes are primarily responsible for lowering individual ops to
128+
// their LLVM equivalent.
129+
static void AddGenericLoweringPasses(mlir::OpPassManager& pm) {
130+
pm.addPass(emitters::CreateSimplifyAffinePass());
131+
pm.addPass(mlir::createCanonicalizerPass());
132+
133+
// simplify-affine lowers most affine.apply ops, but if it can't prove a
134+
// division or modulo is unsigned, affine.apply ops will remain.
135+
pm.addPass(mlir::createLowerAffinePass());
136+
137+
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
138+
pm.addPass(mlir::createSymbolDCEPass());
139+
pm.addPass(mlir::createCSEPass());
140+
141+
pm.addNestedPass<mlir::func::FuncOp>(cpu::CreateExpandFloatOpsPass());
142+
pm.addPass(emitters::CreateExpandFloatOpsPass(/*aproximate_tanh=*/false));
143+
pm.addPass(emitters::CreateEraseDeadFunctionsPass());
144+
pm.addPass(mlir::createLowerAffinePass());
145+
pm.addPass(mlir::createSCFToControlFlowPass());
146+
pm.addPass(emitters::CreateLowerXlaIntrinsicLibPass());
147+
pm.addNestedPass<mlir::func::FuncOp>(CreateConvertMathToLLVMPass());
148+
pm.addPass(emitters::CreateLowerToLLVMPass(/*target_type=*/"cpu"));
149+
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
150+
pm.addPass(mlir::createCanonicalizerPass());
151+
pm.addPass(mlir::createCSEPass());
152+
}
153+
117154
static std::unique_ptr<::mlir::Pass> CreateInlinerAndCsePass() {
118155
return mlir::createCompositeFixedPointPass(
119156
"Inliner", [](mlir::OpPassManager& pm) {
@@ -124,8 +161,12 @@ static std::unique_ptr<::mlir::Pass> CreateInlinerAndCsePass() {
124161
});
125162
}
126163

127-
static void AddLoopTransformationPasses(mlir::OpPassManager& pm,
164+
// Optimizations passes for the "hero" emitters, e.g. loop emitter.
165+
// It is expected that the input has a simple nested loop structure that works
166+
// on scalar instructions extracted/inserted from tensor types.
167+
static void AddScalarOptimizationPasses(mlir::OpPassManager& pm,
128168
int32_t vector_width) {
169+
emitters::RegisterOptimizationPasses(pm);
129170
pm.addPass(CreateAddReductionFastMathFlagsPass());
130171
pm.addPass(CreateInlinerAndCsePass());
131172
pm.addNestedPass<mlir::func::FuncOp>(CreatePeelWorkgroupLoopPass());
@@ -154,8 +195,12 @@ static void AddLoopTransformationPasses(mlir::OpPassManager& pm,
154195
pm.addNestedPass<mlir::func::FuncOp>(CreateAddLoopUnrollFlagsPass());
155196
}
156197

157-
static void AddLoweringPasses(mlir::OpPassManager& pm, int32_t vector_width,
158-
bool fast_min_max) {
198+
// Lowering passes for the "hero" emitters, e.g. loop emitter.
199+
// It is expected that the input has a simple nested loop structure that works
200+
// on scalar instructions extracted/inserted from tensor types.
201+
// The resulting IR can then be translated to native LLVM.
202+
static void AddScalarLoweringPasses(mlir::OpPassManager& pm,
203+
int32_t vector_width, bool fast_min_max) {
159204
pm.addNestedPass<mlir::func::FuncOp>(
160205
emitters::CreateConvertPureCallOpsPass());
161206
pm.addPass(cpu::createLowerToLLVMPass(
@@ -170,28 +215,32 @@ static void AddLoweringPasses(mlir::OpPassManager& pm, int32_t vector_width,
170215
pm.addPass(mlir::createCSEPass());
171216
pm.addNestedPass<mlir::func::FuncOp>(
172217
emitters::CreateSimplifyArithPass(fast_min_max));
173-
pm.addPass(emitters::CreateSimplifyAffinePass());
174-
pm.addPass(mlir::createCanonicalizerPass());
218+
AddGenericLoweringPasses(pm);
219+
}
175220

176-
// simplify-affine lowers most affine.apply ops, but if it can't prove a
177-
// division or modulo is unsigned, affine.apply ops will remain.
178-
pm.addPass(mlir::createLowerAffinePass());
221+
// Optimizations passes for the tiled emitter.
222+
// This is currently very simple but will grow to include tiled optimizations
223+
// such as transpose hoisting and dimension reduction.
224+
static void AddTiledOptimizationPasses(mlir::OpPassManager& pm) {
225+
emitters::RegisterOptimizationPasses(pm);
226+
}
179227

180-
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
181-
pm.addPass(mlir::createSymbolDCEPass());
182-
pm.addPass(mlir::createCSEPass());
228+
// Lowering passes for the tiled emitter.
229+
// The input IR is from the xtile dialect which uses tensors that are converted
230+
// first to the vector dialect and then to LLVM.
231+
static void AddTiledLoweringPasses(mlir::OpPassManager& pm) {
232+
pm.addPass(CreateXTileToVectorPass());
233+
pm.addPass(CreateElementalTensorToVectorPass());
234+
pm.addPass(CreateShloToVectorPass());
235+
pm.addPass(CreateLowerXTileEntryPass());
236+
pm.addPass(cpu::createLowerToLLVMPass());
237+
pm.addPass(mlir::createConvertVectorToSCFPass(
238+
mlir::VectorTransferToSCFOptions().enableFullUnroll(false)));
239+
pm.addPass(mlir::createConvertVectorToLLVMPass());
183240

184-
pm.addNestedPass<mlir::func::FuncOp>(cpu::CreateExpandFloatOpsPass());
185-
pm.addPass(emitters::CreateExpandFloatOpsPass(/*aproximate_tanh=*/false));
186-
pm.addPass(emitters::CreateEraseDeadFunctionsPass());
187-
pm.addPass(mlir::createLowerAffinePass());
188-
pm.addPass(mlir::createSCFToControlFlowPass());
189-
pm.addPass(emitters::CreateLowerXlaIntrinsicLibPass());
190-
pm.addNestedPass<mlir::func::FuncOp>(CreateConvertMathToLLVMPass());
191-
pm.addPass(emitters::CreateLowerToLLVMPass(/*target_type=*/"cpu"));
192-
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
193-
pm.addPass(mlir::createCanonicalizerPass());
194-
pm.addPass(mlir::createCSEPass());
241+
pm.addPass(mlir::createConvertComplexToStandardPass());
242+
243+
AddGenericLoweringPasses(pm);
195244
}
196245

197246
static int GetLlvmFunctionDefCount(mlir::ModuleOp m) {
@@ -223,18 +272,31 @@ FusionCompiler::FusionCompiler(mlir::MLIRContext* context, Options options,
223272
CompilationHooks hooks)
224273
: options_(std::move(options)),
225274
hooks_(std::move(hooks)),
226-
optimization_pass_manager_(
275+
scalar_optimization_pass_manager_(
276+
mlir::PassManager::on<mlir::ModuleOp>(context)),
277+
tiled_optimization_pass_manager_(
278+
mlir::PassManager::on<mlir::ModuleOp>(context)),
279+
scalar_lowering_pass_manager_(
227280
mlir::PassManager::on<mlir::ModuleOp>(context)),
228-
lowering_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)) {
229-
emitters::RegisterOptimizationPasses(optimization_pass_manager_);
230-
AddLoopTransformationPasses(optimization_pass_manager_,
281+
tiled_lowering_pass_manager_(
282+
mlir::PassManager::on<mlir::ModuleOp>(context)) {
283+
// Scalar passes.
284+
AddScalarOptimizationPasses(scalar_optimization_pass_manager_,
231285
options_.vector_width);
232-
optimization_pass_manager_.addInstrumentation(
233-
std::make_unique<TraceInstrumentation>());
286+
AddScalarLoweringPasses(scalar_lowering_pass_manager_, options_.vector_width,
287+
options_.fast_min_max);
288+
289+
// Tiled passes.
290+
AddTiledOptimizationPasses(tiled_optimization_pass_manager_);
291+
AddTiledLoweringPasses(tiled_lowering_pass_manager_);
234292

235-
AddLoweringPasses(lowering_pass_manager_, options_.vector_width,
236-
options_.fast_min_max);
237-
lowering_pass_manager_.addInstrumentation(
293+
scalar_optimization_pass_manager_.addInstrumentation(
294+
std::make_unique<TraceInstrumentation>());
295+
scalar_lowering_pass_manager_.addInstrumentation(
296+
std::make_unique<TraceInstrumentation>());
297+
tiled_optimization_pass_manager_.addInstrumentation(
298+
std::make_unique<TraceInstrumentation>());
299+
tiled_lowering_pass_manager_.addInstrumentation(
238300
std::make_unique<TraceInstrumentation>());
239301
}
240302

@@ -252,6 +314,14 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
252314
});
253315
return count;
254316
};
317+
318+
bool is_tiled = !mlir_module.getBody()->getOps<xtile::EntryFuncOp>().empty();
319+
mlir::PassManager& optimization_pm = is_tiled
320+
? tiled_optimization_pass_manager_
321+
: scalar_optimization_pass_manager_;
322+
mlir::PassManager& lowering_pm =
323+
is_tiled ? tiled_lowering_pass_manager_ : scalar_lowering_pass_manager_;
324+
255325
VLOG(1) << "Compiling MLIR module: " << module_name << ", with "
256326
<< get_module_op_count() << " operations.";
257327
XLA_SCOPED_LOGGING_TIMER_LEVEL(
@@ -266,15 +336,15 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
266336
if (hooks_.pre_optimization) {
267337
hooks_.pre_optimization(mlir_module);
268338
}
269-
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, optimization_pass_manager_,
270-
nullptr, options_.verification_level));
339+
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, optimization_pm, nullptr,
340+
options_.verification_level));
271341

272342
if (hooks_.post_optimization) {
273343
hooks_.post_optimization(mlir_module);
274344
}
275345

276-
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, lowering_pass_manager_,
277-
nullptr, options_.verification_level));
346+
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, lowering_pm, nullptr,
347+
options_.verification_level));
278348

279349
if (hooks_.post_lowering) {
280350
hooks_.post_lowering(mlir_module);
@@ -347,14 +417,17 @@ std::unique_ptr<mlir::MLIRContext> FusionCompiler::CreateContext() {
347417
xla::cpu::XlaCpuDialect, mlir::mhlo::MhloDialect,
348418
mlir::scf::SCFDialect, mlir::LLVM::LLVMDialect,
349419
mlir::tensor::TensorDialect, mlir::vector::VectorDialect,
350-
xla::XlaDialect>();
420+
xla::XlaDialect, xla::xtile::XTileDialect>();
351421

352422
mlir::DialectRegistry registry;
353423
mlir::LLVM::registerInlinerInterface(registry);
354424
mlir::func::registerInlinerExtension(registry);
355425
mlir::registerLLVMDialectTranslation(registry);
356426
mlir::registerBuiltinDialectTranslation(registry);
357427
mlir::registerConvertMathToLLVMInterface(registry);
428+
mlir::registerConvertMemRefToLLVMInterface(registry);
429+
mlir::ub::registerConvertUBToLLVMInterface(registry);
430+
mlir::vector::registerConvertVectorToLLVMInterface(registry);
358431
context->appendDialectRegistry(registry);
359432

360433
return context;

xla/backends/cpu/codegen/fusion_compiler.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,19 @@ class FusionCompiler {
6767
private:
6868
Options options_;
6969
CompilationHooks hooks_;
70+
// The reason we have 4 distinct pass managers is because:
71+
// - We have 2 stages: optimization and lowering, this is to enable dumping
72+
// of the intermediate optimized MLIR.
73+
// - We have 2 distinct pipelines for scalar and tiled kernels, this is
74+
// because they differ slightly in their semantics, ideally these would be
75+
// unified but this is a larger change.
7076
// Pass manager that holds the optimization & loop transformation passes.
71-
mlir::PassManager optimization_pass_manager_;
77+
mlir::PassManager scalar_optimization_pass_manager_;
78+
mlir::PassManager tiled_optimization_pass_manager_;
7279
// Pass manager that holds the passes responsible for lowering the module from
7380
// MLIR to LLVM.
74-
mlir::PassManager lowering_pass_manager_;
81+
mlir::PassManager scalar_lowering_pass_manager_;
82+
mlir::PassManager tiled_lowering_pass_manager_;
7583
};
7684

7785
} // namespace xla::cpu
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("//xla:py_strict.bzl", "py_strict_test")
2+
3+
package(
4+
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
5+
licenses = ["notice"],
6+
)
7+
8+
py_strict_test(
9+
name = "tiled_kernel_test",
10+
srcs = ["tiled_kernel_test.py"],
11+
main = "tiled_kernel_test.py",
12+
tags = [
13+
"no_oss",
14+
],
15+
deps = [
16+
"//third_party/py/numpy",
17+
"//xla:xla_data_proto_py",
18+
"//xla/backends/cpu/testlib",
19+
"//xla/codegen/testlib",
20+
"@absl_py//absl/testing:absltest",
21+
],
22+
)

0 commit comments

Comments
 (0)