Skip to content

Commit d0c6317

Browse files
committed
Add a pass option to provide the XeGPU code level.
XeGPU allows worgroup, subgroup, and workitem level programming. This options lets the pass manager know at which level the XeGPU ops belong to.
1 parent 82f34dd commit d0c6317

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,17 @@ struct GPUToNVVMPipelineOptions
6363
// Options for the gpu to xevm pipeline.
6464
struct GPUToXeVMPipelineOptions
6565
: public PassPipelineOptions<GPUToXeVMPipelineOptions> {
66+
// XeGPU op granularity selection: workgroup | subgroup | workitem
67+
PassOptions::Option<std::string> xegpuOpLevel{
68+
*this, "xegpu-op-level",
69+
llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
70+
"subgroup | workitem"),
71+
llvm::cl::init("workgroup")};
6672
// General lowering controls.
67-
PassOptions::Option<int64_t> indexBitWidth{
68-
*this, "index-bitwidth",
73+
PassOptions::Option<bool> use64bitIndex{
74+
*this, "use-64bit-index",
6975
llvm::cl::desc("Bitwidth of the index type (host & device)"),
70-
llvm::cl::init(64)};
76+
llvm::cl::init(true)};
7177
PassOptions::Option<bool> kernelBarePtrCallConv{
7278
*this, "kernel-bare-ptr-calling-convention",
7379
llvm::cl::desc("Use bare pointer calling convention for device kernels"),

mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,30 @@ void buildCommonPassPipeline(
6262
//===----------------------------------------------------------------------===//
6363
void buildGpuPassPipeline(OpPassManager &pm,
6464
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
65-
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
66-
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
67-
pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
68-
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
69-
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
70-
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
71-
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
72-
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
73-
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
74-
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
75-
pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
76-
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
77-
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
65+
if (options.xegpuOpLevel == "workgroup") {
66+
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
67+
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
68+
pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
69+
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
70+
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
71+
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
72+
}
73+
if (options.xegpuOpLevel == "subgroup") {
74+
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
75+
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
76+
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
77+
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
78+
pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
79+
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
80+
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
81+
}
7882
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
79-
ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
80-
gpuToLLVMSPVOptions.use64bitIndex = options.indexBitWidth;
81-
pm.addNestedPass<gpu::GPUModuleOp>(
82-
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
83+
{
84+
ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
85+
gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
86+
pm.addNestedPass<gpu::GPUModuleOp>(
87+
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
88+
}
8389
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeVMToLLVMPass());
8490
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
8591
}
@@ -104,14 +110,14 @@ void buildHostPostPipeline(OpPassManager &pm,
104110
}
105111
pm.addPass(createConvertToLLVMPass());
106112
pm.addPass(createLowerAffinePass());
113+
pm.addPass(createReconcileUnrealizedCastsPass());
107114
// gpu-module-to-binary
108115
{
109116
GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
110117
gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
111118
gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
112119
pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
113120
}
114-
pm.addPass(createReconcileUnrealizedCastsPass());
115121
}
116122
} // namespace
117123

0 commit comments

Comments
 (0)