Skip to content

Commit 1256ceb

Browse files
apaszkejax authors
authored andcommitted
[Mosaic GPU] Rearrange the pass pipeline (again)
PiperOrigin-RevId: 642256145
1 parent 3345952 commit 1256ceb

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
105105
mlir::registerConvertFuncToLLVMPass();
106106
mlir::registerConvertAffineToStandard();
107107
mlir::registerReconcileUnrealizedCasts();
108-
mlir::registerGpuToLLVMConversionPass();
109108
// TODO(apaszke): Only register the passes we actually use.
110109
mlir::memref::registerMemRefPasses();
110+
mlir::registerConvertToLLVMPass();
111111
mlir::registerGPUPasses();
112112
mosaic::gpu::registerGpuLaunchLoweringPass();
113113
mosaic::gpu::registerConvertGpuToLLVMPass();
@@ -140,11 +140,12 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
140140
convert-math-to-llvm{approximate-log1p=true},
141141
canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},
142142
cse,
143-
reconcile-unrealized-casts,)" +
143+
)" +
144144
(target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering,"
145145
: "") +
146146
R"(
147-
convert-func-to-llvm{index-bitwidth=0 use-bare-ptr-memref-call-conv=false}
147+
convert-to-llvm,
148+
reconcile-unrealized-casts
148149
)
149150
)");
150151
}
@@ -170,9 +171,9 @@ void InitContext(mlir::MLIRContext* context) {
170171
mlir::registerConvertFuncToLLVMInterface(registry);
171172
mlir::index::registerConvertIndexToLLVMInterface(registry);
172173
mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
173-
mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this
174+
mlir::ub::registerConvertUBToLLVMInterface(registry);
174175
mlir::arith::registerConvertArithToLLVMInterface(registry);
175-
mlir::registerFinalizeMemRefToLLVMConversionPass();
176+
mlir::registerConvertMemRefToLLVMInterface(registry);
176177
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
177178
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry);
178179
mlir::registerBuiltinDialectTranslation(registry);

0 commit comments

Comments
 (0)