2929namespace mlir ::gc {
3030
3131void populateGPUPipeline (OpPassManager &pm,
32- const GPUPipelineOption &pipelineOption) {
33- // Add an argument for the GPU context
34- pm.addNestedPass <func::FuncOp>(createAddContextArg ());
32+ const GPUPipelineOptions &pipelineOpts) {
33+ if (pipelineOpts.useGpuRuntime ) {
34+ // Add an argument for the GPU context
35+ pm.addNestedPass <func::FuncOp>(createAddContextArg ());
36+ }
3537
3638 pm.addNestedPass <func::FuncOp>(createIterativeTilingAndFusion ());
3739
@@ -72,10 +74,9 @@ void populateGPUPipeline(OpPassManager &pm,
7274
7375 imex::InsertGPUAllocsOptions insertGPUAllocsOption{
7476 /* clientAPI*/ " opencl" , /* inRegions*/ false ,
75- /* isUsmArgs*/ pipelineOption .isUsmArgs . getValue () };
77+ /* isUsmArgs*/ pipelineOpts .isUsmArgs };
7678 pm.addNestedPass <func::FuncOp>(
7779 imex::createInsertGPUAllocsPass (insertGPUAllocsOption));
78-
7980 pm.addPass (createGpuKernelOutliningPass ());
8081 pm.addPass (createCanonicalizerPass ());
8182 pm.addPass (imex::createSetSPIRVCapabilitiesPass ());
@@ -94,14 +95,25 @@ void populateGPUPipeline(OpPassManager &pm,
9495 pm.addNestedPass <func::FuncOp>(LLVM::createRequestCWrappersPass ());
9596 pm.addPass (imex::createSerializeSPIRVPass ());
9697 pm.addPass (createConvertVectorToSCFPass ());
98+
99+ if (!pipelineOpts.useGpuRuntime ) {
100+ pm.addPass (imex::createConvertGPUToGPUXPass ());
101+ }
102+
97103 pm.addPass (createConvertSCFToCFPass ());
98104 pm.addPass (createConvertControlFlowToLLVMPass ());
99105 pm.addPass (createConvertVectorToLLVMPass ());
100106 pm.addPass (createConvertIndexToLLVMPass ());
101107 pm.addPass (createArithToLLVMConversionPass ());
102108 pm.addPass (createConvertFuncToLLVMPass ());
103109 pm.addPass (createConvertMathToLLVMPass ());
104- pm.addPass (createGpuToGpuOcl ({pipelineOption.callFinish }));
110+
111+ if (pipelineOpts.useGpuRuntime ) {
112+ pm.addPass (createGpuToGpuOcl ({pipelineOpts.callFinish }));
113+ } else {
114+ pm.addPass (imex::createConvertGPUXToLLVMPass ());
115+ }
116+
105117 pm.addPass (createConvertIndexToLLVMPass ());
106118 pm.addPass (memref::createExpandStridedMetadataPass ());
107119 pm.addPass (createLowerAffinePass ());
@@ -110,7 +122,7 @@ void populateGPUPipeline(OpPassManager &pm,
110122}
111123
112124void registerGPUPipeline () {
113- PassPipelineRegistration<GPUPipelineOption >(
125+ PassPipelineRegistration<GPUPipelineOptions >(
114126 " gc-gpu-pipeline" , " The GPU pipeline for Graph Compiler with IMEX" ,
115127 populateGPUPipeline);
116128}
0 commit comments