@@ -38,8 +38,13 @@ limitations under the License.
3838#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
3939#include " mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
4040#include " mlir/Conversion/MathToLLVM/MathToLLVM.h"
41+ #include " mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4142#include " mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
4243#include " mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
44+ #include " mlir/Conversion/UBToLLVM/UBToLLVM.h"
45+ #include " mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
46+ #include " mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
47+ #include " mlir/Conversion/VectorToSCF/VectorToSCF.h"
4348#include " mlir/Dialect/Affine/IR/AffineOps.h"
4449#include " mlir/Dialect/Affine/Passes.h"
4550#include " mlir/Dialect/Arith/IR/Arith.h"
@@ -58,6 +63,7 @@ limitations under the License.
5863#include " mlir/IR/Attributes.h"
5964#include " mlir/IR/BuiltinAttributes.h"
6065#include " mlir/IR/BuiltinOps.h"
66+ #include " mlir/IR/Operation.h"
6167#include " mlir/IR/Visitors.h"
6268#include " mlir/Pass/PassManager.h"
6369#include " mlir/Support/LLVM.h"
@@ -70,6 +76,7 @@ limitations under the License.
7076#include " xla/backends/cpu/codegen/emitters/ir/xla_cpu_dialect.h"
7177#include " xla/backends/cpu/codegen/emitters/transforms/passes.h"
7278#include " xla/backends/cpu/codegen/kernel_api_ir_builder.h"
79+ #include " xla/backends/cpu/codegen/tiled/transforms/passes.h"
7380#include " xla/codegen/emitters/ir/xla_attrs.h.inc"
7481#include " xla/codegen/emitters/ir/xla_dialect.h"
7582#include " xla/codegen/emitters/ir/xla_ops.h"
@@ -78,6 +85,8 @@ limitations under the License.
7885#include " xla/codegen/llvm_ir_kernel_source.h"
7986#include " xla/codegen/mlir_kernel_source.h"
8087#include " xla/codegen/trace_pass_instrumentation.h"
88+ #include " xla/codegen/xtile/ir/xtile_dialect.h"
89+ #include " xla/codegen/xtile/ir/xtile_ops.h"
8190#include " xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h"
8291#include " xla/mlir_hlo/mhlo/IR/hlo_ops.h"
8392#include " xla/status_macros.h"
@@ -114,6 +123,34 @@ static std::unique_ptr<::mlir::Pass> CreateConvertMathToLLVMPass() {
114123 return mlir::createConvertMathToLLVMPass (options);
115124}
116125
126+ // The final lowering passes common to both scalar and tiled kernels.
127+ // These passes are primarily responsible for lowering individual ops to
128+ // their LLVM equivalent.
129+ static void AddGenericLoweringPasses (mlir::OpPassManager& pm) {
130+ pm.addPass (emitters::CreateSimplifyAffinePass ());
131+ pm.addPass (mlir::createCanonicalizerPass ());
132+
133+ // simplify-affine lowers most affine.apply ops, but if it can't prove a
134+ // division or modulo is unsigned, affine.apply ops will remain.
135+ pm.addPass (mlir::createLowerAffinePass ());
136+
137+ pm.addPass (mlir::createLoopInvariantCodeMotionPass ());
138+ pm.addPass (mlir::createSymbolDCEPass ());
139+ pm.addPass (mlir::createCSEPass ());
140+
141+ pm.addNestedPass <mlir::func::FuncOp>(cpu::CreateExpandFloatOpsPass ());
142+ pm.addPass (emitters::CreateExpandFloatOpsPass (/* aproximate_tanh=*/ false ));
143+ pm.addPass (emitters::CreateEraseDeadFunctionsPass ());
144+ pm.addPass (mlir::createLowerAffinePass ());
145+ pm.addPass (mlir::createSCFToControlFlowPass ());
146+ pm.addPass (emitters::CreateLowerXlaIntrinsicLibPass ());
147+ pm.addNestedPass <mlir::func::FuncOp>(CreateConvertMathToLLVMPass ());
148+ pm.addPass (emitters::CreateLowerToLLVMPass (/* target_type=*/ " cpu" ));
149+ pm.addPass (mlir::createReconcileUnrealizedCastsPass ());
150+ pm.addPass (mlir::createCanonicalizerPass ());
151+ pm.addPass (mlir::createCSEPass ());
152+ }
153+
117154static std::unique_ptr<::mlir::Pass> CreateInlinerAndCsePass () {
118155 return mlir::createCompositeFixedPointPass (
119156 " Inliner" , [](mlir::OpPassManager& pm) {
@@ -124,8 +161,12 @@ static std::unique_ptr<::mlir::Pass> CreateInlinerAndCsePass() {
124161 });
125162}
126163
127- static void AddLoopTransformationPasses (mlir::OpPassManager& pm,
164+ // Optimizations passes for the "hero" emitters, e.g. loop emitter.
165+ // It is expected that the input has a simple nested loop structure that works
166+ // on scalar instructions extracted/inserted from tensor types.
167+ static void AddScalarOptimizationPasses (mlir::OpPassManager& pm,
128168 int32_t vector_width) {
169+ emitters::RegisterOptimizationPasses (pm);
129170 pm.addPass (CreateAddReductionFastMathFlagsPass ());
130171 pm.addPass (CreateInlinerAndCsePass ());
131172 pm.addNestedPass <mlir::func::FuncOp>(CreatePeelWorkgroupLoopPass ());
@@ -154,8 +195,12 @@ static void AddLoopTransformationPasses(mlir::OpPassManager& pm,
154195 pm.addNestedPass <mlir::func::FuncOp>(CreateAddLoopUnrollFlagsPass ());
155196}
156197
157- static void AddLoweringPasses (mlir::OpPassManager& pm, int32_t vector_width,
158- bool fast_min_max) {
198+ // Lowering passes for the "hero" emitters, e.g. loop emitter.
199+ // It is expected that the input has a simple nested loop structure that works
200+ // on scalar instructions extracted/inserted from tensor types.
201+ // The resulting IR can then be translated to native LLVM.
202+ static void AddScalarLoweringPasses (mlir::OpPassManager& pm,
203+ int32_t vector_width, bool fast_min_max) {
159204 pm.addNestedPass <mlir::func::FuncOp>(
160205 emitters::CreateConvertPureCallOpsPass ());
161206 pm.addPass (cpu::createLowerToLLVMPass (
@@ -170,28 +215,32 @@ static void AddLoweringPasses(mlir::OpPassManager& pm, int32_t vector_width,
170215 pm.addPass (mlir::createCSEPass ());
171216 pm.addNestedPass <mlir::func::FuncOp>(
172217 emitters::CreateSimplifyArithPass (fast_min_max));
173- pm. addPass ( emitters::CreateSimplifyAffinePass () );
174- pm. addPass ( mlir::createCanonicalizerPass ());
218+ AddGenericLoweringPasses (pm );
219+ }
175220
176- // simplify-affine lowers most affine.apply ops, but if it can't prove a
177- // division or modulo is unsigned, affine.apply ops will remain.
178- pm.addPass (mlir::createLowerAffinePass ());
221+ // Optimizations passes for the tiled emitter.
222+ // This is currently very simple but will grow to include tiled optimizations
223+ // such as transpose hoisting and dimension reduction.
224+ static void AddTiledOptimizationPasses (mlir::OpPassManager& pm) {
225+ emitters::RegisterOptimizationPasses (pm);
226+ }
179227
180- pm.addPass (mlir::createLoopInvariantCodeMotionPass ());
181- pm.addPass (mlir::createSymbolDCEPass ());
182- pm.addPass (mlir::createCSEPass ());
228+ // Lowering passes for the tiled emitter.
229+ // The input IR is from the xtile dialect which uses tensors that are converted
230+ // first to the vector dialect and then to LLVM.
231+ static void AddTiledLoweringPasses (mlir::OpPassManager& pm) {
232+ pm.addPass (CreateXTileToVectorPass ());
233+ pm.addPass (CreateElementalTensorToVectorPass ());
234+ pm.addPass (CreateShloToVectorPass ());
235+ pm.addPass (CreateLowerXTileEntryPass ());
236+ pm.addPass (cpu::createLowerToLLVMPass ());
237+ pm.addPass (mlir::createConvertVectorToSCFPass (
238+ mlir::VectorTransferToSCFOptions ().enableFullUnroll (false )));
239+ pm.addPass (mlir::createConvertVectorToLLVMPass ());
183240
184- pm.addNestedPass <mlir::func::FuncOp>(cpu::CreateExpandFloatOpsPass ());
185- pm.addPass (emitters::CreateExpandFloatOpsPass (/* aproximate_tanh=*/ false ));
186- pm.addPass (emitters::CreateEraseDeadFunctionsPass ());
187- pm.addPass (mlir::createLowerAffinePass ());
188- pm.addPass (mlir::createSCFToControlFlowPass ());
189- pm.addPass (emitters::CreateLowerXlaIntrinsicLibPass ());
190- pm.addNestedPass <mlir::func::FuncOp>(CreateConvertMathToLLVMPass ());
191- pm.addPass (emitters::CreateLowerToLLVMPass (/* target_type=*/ " cpu" ));
192- pm.addPass (mlir::createReconcileUnrealizedCastsPass ());
193- pm.addPass (mlir::createCanonicalizerPass ());
194- pm.addPass (mlir::createCSEPass ());
241+ pm.addPass (mlir::createConvertComplexToStandardPass ());
242+
243+ AddGenericLoweringPasses (pm);
195244}
196245
197246static int GetLlvmFunctionDefCount (mlir::ModuleOp m) {
@@ -223,18 +272,31 @@ FusionCompiler::FusionCompiler(mlir::MLIRContext* context, Options options,
223272 CompilationHooks hooks)
224273 : options_(std::move(options)),
225274 hooks_ (std::move(hooks)),
226- optimization_pass_manager_(
275+ scalar_optimization_pass_manager_(
276+ mlir::PassManager::on<mlir::ModuleOp>(context)),
277+ tiled_optimization_pass_manager_(
278+ mlir::PassManager::on<mlir::ModuleOp>(context)),
279+ scalar_lowering_pass_manager_(
227280 mlir::PassManager::on<mlir::ModuleOp>(context)),
228- lowering_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)) {
229- emitters::RegisterOptimizationPasses (optimization_pass_manager_);
230- AddLoopTransformationPasses (optimization_pass_manager_,
281+ tiled_lowering_pass_manager_(
282+ mlir::PassManager::on<mlir::ModuleOp>(context)) {
283+ // Scalar passes.
284+ AddScalarOptimizationPasses (scalar_optimization_pass_manager_,
231285 options_.vector_width );
232- optimization_pass_manager_.addInstrumentation (
233- std::make_unique<TraceInstrumentation>());
286+ AddScalarLoweringPasses (scalar_lowering_pass_manager_, options_.vector_width ,
287+ options_.fast_min_max );
288+
289+ // Tiled passes.
290+ AddTiledOptimizationPasses (tiled_optimization_pass_manager_);
291+ AddTiledLoweringPasses (tiled_lowering_pass_manager_);
234292
235- AddLoweringPasses (lowering_pass_manager_, options_.vector_width ,
236- options_.fast_min_max );
237- lowering_pass_manager_.addInstrumentation (
293+ scalar_optimization_pass_manager_.addInstrumentation (
294+ std::make_unique<TraceInstrumentation>());
295+ scalar_lowering_pass_manager_.addInstrumentation (
296+ std::make_unique<TraceInstrumentation>());
297+ tiled_optimization_pass_manager_.addInstrumentation (
298+ std::make_unique<TraceInstrumentation>());
299+ tiled_lowering_pass_manager_.addInstrumentation (
238300 std::make_unique<TraceInstrumentation>());
239301}
240302
@@ -252,6 +314,14 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
252314 });
253315 return count;
254316 };
317+
318+ bool is_tiled = !mlir_module.getBody ()->getOps <xtile::EntryFuncOp>().empty ();
319+ mlir::PassManager& optimization_pm = is_tiled
320+ ? tiled_optimization_pass_manager_
321+ : scalar_optimization_pass_manager_;
322+ mlir::PassManager& lowering_pm =
323+ is_tiled ? tiled_lowering_pass_manager_ : scalar_lowering_pass_manager_;
324+
255325 VLOG (1 ) << " Compiling MLIR module: " << module_name << " , with "
256326 << get_module_op_count () << " operations." ;
257327 XLA_SCOPED_LOGGING_TIMER_LEVEL (
@@ -266,15 +336,15 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
266336 if (hooks_.pre_optimization ) {
267337 hooks_.pre_optimization (mlir_module);
268338 }
269- TF_RETURN_IF_ERROR (RunPassPipeline (mlir_module, optimization_pass_manager_ ,
270- nullptr , options_.verification_level ));
339+ TF_RETURN_IF_ERROR (RunPassPipeline (mlir_module, optimization_pm, nullptr ,
340+ options_.verification_level ));
271341
272342 if (hooks_.post_optimization ) {
273343 hooks_.post_optimization (mlir_module);
274344 }
275345
276- TF_RETURN_IF_ERROR (RunPassPipeline (mlir_module, lowering_pass_manager_ ,
277- nullptr , options_.verification_level ));
346+ TF_RETURN_IF_ERROR (RunPassPipeline (mlir_module, lowering_pm, nullptr ,
347+ options_.verification_level ));
278348
279349 if (hooks_.post_lowering ) {
280350 hooks_.post_lowering (mlir_module);
@@ -347,14 +417,17 @@ std::unique_ptr<mlir::MLIRContext> FusionCompiler::CreateContext() {
347417 xla::cpu::XlaCpuDialect, mlir::mhlo::MhloDialect,
348418 mlir::scf::SCFDialect, mlir::LLVM::LLVMDialect,
349419 mlir::tensor::TensorDialect, mlir::vector::VectorDialect,
350- xla::XlaDialect>();
420+ xla::XlaDialect, xla::xtile::XTileDialect >();
351421
352422 mlir::DialectRegistry registry;
353423 mlir::LLVM::registerInlinerInterface (registry);
354424 mlir::func::registerInlinerExtension (registry);
355425 mlir::registerLLVMDialectTranslation (registry);
356426 mlir::registerBuiltinDialectTranslation (registry);
357427 mlir::registerConvertMathToLLVMInterface (registry);
428+ mlir::registerConvertMemRefToLLVMInterface (registry);
429+ mlir::ub::registerConvertUBToLLVMInterface (registry);
430+ mlir::vector::registerConvertVectorToLLVMInterface (registry);
358431 context->appendDialectRegistry (registry);
359432
360433 return context;
0 commit comments