Skip to content

Commit 1cbd14c

Browse files
committed
add attach target pass
1 parent b7129d8 commit 1cbd14c

File tree

8 files changed

+161
-2
lines changed

8 files changed

+161
-2
lines changed

include/gc/Transforms/Passes.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,41 @@ def GpuTilingAndFusion : Pass<"gpu-tiling", "func::FuncOp"> {
150150
"The maximum workgroup size.">
151151
];
152152
}
153+
154+
def GpuXeVMAttachTarget: Pass<"xevm-attach-target", ""> {
155+
let summary = "Attaches a XeVM target attribute to a GPU Module.";
156+
let description = [{
157+
This pass searches for all GPU Modules in the immediate regions and attaches
158+
a XeVM target if the module matches the name specified by the `module` argument.
159+
160+
Example:
161+
```
162+
// File: in.mlir:
163+
gpu.module @xevm_module_1 {...}
164+
gpu.module @xevm_module_2 {...}
165+
gpu.module @xevm_module_1 {...}
166+
// mlir-opt --xevm-attach-target="module=xevm.* chip=pvc" in.mlir
167+
gpu.module @xevm_module_1 {...}
168+
gpu.module @xevm_module_2 {...}
169+
gpu.module @xevm_module_1 [#xevm.target<chip = "pvc">] {...}
170+
```
171+
}];
172+
let options = [
173+
Option<"moduleMatcher", "module", "std::string",
174+
/*default=*/ [{""}],
175+
"Regex used to identify the modules to attach the target to.">,
176+
Option<"triple", "triple", "std::string",
177+
/*default=*/ "\"spirv64-unknown-unknown\"",
178+
"Target triple.">,
179+
Option<"chip", "chip", "std::string",
180+
/*default=*/"\"pvc\"",
181+
"Target chip.">,
182+
Option<"optLevel", "O", "unsigned",
183+
/*default=*/"2",
184+
"Optimization level.">
185+
];
186+
}
187+
153188
#endif // GC_USE_IMEX
154189

155190
def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",

lib/gc/Target/LLVM/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ gc_add_mlir_dialect_library(MLIRXeVMTarget
66
ADDITIONAL_HEADER_DIRS
77
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
88
${PROJECT_SOURCE_DIR}/include/gc/Dialect/LLVMIR
9+
10+
LINK_COMPONENTS
11+
SPIRVCodeGen
912

1013
LINK_LIBS PUBLIC
1114
MLIRIR

lib/gc/Transforms/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ gc_add_mlir_library(GcGpuPasses
1717
GpuToGpuOcl.cpp
1818
LinalgToXeGPU.cpp
1919
Pipeline.cpp
20+
XeVMAttachTarget.cpp
2021

2122
DEPENDS
2223
GraphCompilerPassIncGen
@@ -31,6 +32,7 @@ gc_add_mlir_library(GcGpuPasses
3132
MLIRMathToSPIRV
3233
MLIRControlFlowToSPIRV
3334
MLIRMemRefTransforms
35+
MLIRXeVMToLLVMIRTranslation
3436
GcInterface
3537
GcUtilsIR
3638
${IMEX_LIBS}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===-- XeVMAttachTarget.cpp - DESC -----------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the `GpuXeVMAttachTarget` pass, attaching `#xevm.target`
10+
// attributes to GPU modules.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "gc/Dialect/LLVMIR/XeVMDialect.h"
15+
16+
#include "gc/Target/LLVM/XeVM/Target.h"
17+
#include "gc/Transforms/Passes.h"
18+
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
21+
#include "mlir/IR/Builders.h"
22+
#include "mlir/Pass/Pass.h"
23+
#include "llvm/Support/Regex.h"
24+
#include <iostream>
25+
26+
namespace mlir {
27+
namespace gc {
28+
#define GEN_PASS_DEF_GPUXEVMATTACHTARGET
29+
#include "gc/Transforms/Passes.h.inc"
30+
} // namespace gc
31+
} // namespace mlir
32+
33+
using namespace mlir::xevm;
34+
using namespace mlir;
35+
36+
namespace {
37+
struct XeVMAttachTarget
38+
: public gc::impl::GpuXeVMAttachTargetBase<XeVMAttachTarget> {
39+
using Base::Base;
40+
41+
DictionaryAttr getFlags(OpBuilder &builder) const;
42+
43+
void runOnOperation() override;
44+
45+
void getDependentDialects(DialectRegistry &registry) const override {
46+
registry.insert<xevm::XeVMDialect>();
47+
}
48+
};
49+
} // namespace
50+
51+
DictionaryAttr XeVMAttachTarget::getFlags(OpBuilder &builder) const {
52+
UnitAttr unitAttr = builder.getUnitAttr();
53+
SmallVector<NamedAttribute, 2> flags;
54+
auto addFlag = [&](StringRef flag) {
55+
flags.push_back(builder.getNamedAttr(flag, unitAttr));
56+
};
57+
if (!flags.empty())
58+
return builder.getDictionaryAttr(flags);
59+
return nullptr;
60+
}
61+
62+
void XeVMAttachTarget::runOnOperation() {
63+
OpBuilder builder(&getContext());
64+
auto target = builder.getAttr<XeVMTargetAttr>(optLevel, triple, chip);
65+
llvm::Regex matcher(moduleMatcher);
66+
for (Region &region : getOperation()->getRegions())
67+
for (Block &block : region.getBlocks())
68+
for (auto module : block.getOps<gpu::GPUModuleOp>()) {
69+
// Check if the name of the module matches.
70+
if (!moduleMatcher.empty() && !matcher.match(module.getName()))
71+
continue;
72+
// Create the target array.
73+
SmallVector<Attribute> targets;
74+
if (std::optional<ArrayAttr> attrs = module.getTargets())
75+
targets.append(attrs->getValue().begin(), attrs->getValue().end());
76+
targets.push_back(target);
77+
// Remove any duplicate targets.
78+
targets.erase(llvm::unique(targets), targets.end());
79+
// Update the target attribute array.
80+
module.setTargetsAttr(builder.getArrayAttr(targets));
81+
}
82+
}

src/gc-opt/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ if(GC_DEV_LINK_LLVM_DYLIB)
2929
else()
3030
set(MLIR_LINK_COMPONENTS
3131
MLIROptLib
32+
MLIRBuiltinToLLVMIRTranslation
33+
MLIRLLVMDialect
34+
MLIRLLVMToLLVMIRTranslation
35+
MLIRToLLVMIRTranslationRegistration
3236
)
3337
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
3438
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
@@ -38,9 +42,12 @@ add_llvm_executable(gc-opt gc-opt.cpp)
3842
llvm_update_compile_flags(gc-opt)
3943
mlir_check_all_link_libraries(gc-opt)
4044

45+
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
46+
4147
target_link_libraries(gc-opt PUBLIC GcInterface)
4248
target_link_libraries(gc-opt PRIVATE
4349
${dialect_libs}
50+
${extension_libs}
4451
${conversion_libs}
4552
${MLIR_LINK_COMPONENTS}
4653
GcPasses

src/gc-opt/gc-opt.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@
2424
#ifdef GC_HAS_ONEDNN_DIALECT
2525
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
2626
#endif
27+
#include "gc/Target/LLVM/XeVM/Target.h"
28+
#include "gc/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
2729
#include "gc/Transforms/Microkernel/MicrokernelPasses.h"
2830
#include "gc/Transforms/Passes.h"
2931
#include "mlir/InitAllDialects.h"
32+
#include "mlir/InitAllExtensions.h"
3033
#include "mlir/InitAllPasses.h"
34+
#include "mlir/Target/LLVMIR/Dialect/All.h"
3135
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
3236

3337
#ifdef GC_USE_IMEX
@@ -66,11 +70,15 @@ int main(int argc, char *argv[]) {
6670
registry.insert<mlir::cpuruntime::CPURuntimeDialect>();
6771
registry.insert<mlir::linalgx::LinalgxDialect>();
6872
registry.insert<mlir::microkernel::MicrokernelDialect>();
69-
registry.insert<mlir::xevm::XeVMDialect>();
7073
mlir::registerAllDialects(registry);
7174
#ifdef GC_USE_IMEX
72-
registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>();
75+
registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect,
76+
mlir::xevm::XeVMDialect>();
77+
mlir::registerXeVMDialectTranslation(registry);
78+
mlir::xevm::registerXeVMTargetInterfaceExternalModels(registry);
7379
#endif
80+
mlir::registerAllExtensions(registry);
81+
mlir::registerAllToLLVMIRTranslations(registry);
7482
mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
7583
return mlir::asMainReturnCode(mlir::MlirOptMain(
7684
argc, argv, "Graph Compiler modular optimizer driver\n", registry));
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: gc-opt %s --gpu-to-llvm --convert-gpu-to-llvm-spv --gpu-module-to-binary | FileCheck %s
2+
3+
module attributes {gpu.container_module} {
4+
// CHECK-LABEL:gpu.binary @entry_kernel
5+
// CHECK:[#gpu.object<#xevm.target,
6+
gpu.module @entry_kernel [#xevm.target] {
7+
gpu.func @entry_kernel(%arg0: index) kernel attributes {} {
8+
gpu.return
9+
}
10+
}
11+
}
12+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: gc-opt %s --xevm-attach-target | FileCheck %s
2+
module attributes {gpu.container_module} {
3+
//CHECK:gpu.module @entry_kernel [#xevm.target]
4+
gpu.module @entry_kernel {
5+
gpu.func @entry_kernel(%arg0: index) kernel attributes {} {
6+
gpu.return
7+
}
8+
}
9+
}
10+

0 commit comments

Comments
 (0)