|
1 | | -//===-- XeVMAttachTarget.cpp - DESC -----------------------------*- C++ -*-===// |
| 1 | +//===-- XeVMAttachTarget.cpp - Attach an XeVM target ----------------------===// |
2 | 2 | // |
3 | 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | 4 | // See https://llvm.org/LICENSE.txt for license information. |
@@ -68,21 +68,24 @@ void XeVMAttachTarget::runOnOperation() { |
68 | 68 | OpBuilder builder(&getContext()); |
69 | 69 | ArrayRef<std::string> libs(linkLibs); |
70 | 70 | SmallVector<StringRef> filesToLink(libs); |
71 | | - auto target = builder.getAttr<mlir::xevm::XeVMTargetAttr>( |
| 71 | + auto target = builder.getAttr<xevm::XeVMTargetAttr>( |
72 | 72 | optLevel, triple, chip, getFlags(builder), |
73 | 73 | filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); |
74 | 74 | llvm::Regex matcher(moduleMatcher); |
75 | | - // Check if the name of the module matches. |
76 | | - auto gpuModule = cast<gpu::GPUModuleOp>(getOperation()); |
77 | | - if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName())) |
78 | | - return; |
79 | | - // Create the target array. |
80 | | - SmallVector<Attribute> targets; |
81 | | - if (std::optional<ArrayAttr> attrs = gpuModule.getTargets()) |
82 | | - targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
83 | | - targets.push_back(target); |
84 | | - // Remove any duplicate targets. |
85 | | - targets.erase(llvm::unique(targets), targets.end()); |
86 | | - // Update the target attribute array. |
87 | | - gpuModule.setTargetsAttr(builder.getArrayAttr(targets)); |
| 75 | + for (Region ®ion : getOperation()->getRegions()) |
| 76 | + for (Block &block : region.getBlocks()) |
| 77 | + for (auto module : block.getOps<gpu::GPUModuleOp>()) { |
| 78 | + // Check if the name of the module matches. |
| 79 | + if (!moduleMatcher.empty() && !matcher.match(module.getName())) |
| 80 | + continue; |
| 81 | + // Create the target array. |
| 82 | + SmallVector<Attribute> targets; |
| 83 | + if (std::optional<ArrayAttr> attrs = module.getTargets()) |
| 84 | + targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
| 85 | + targets.push_back(target); |
| 86 | + // Remove any duplicate targets. |
| 87 | + targets.erase(llvm::unique(targets), targets.end()); |
| 88 | + // Update the target attribute array. |
| 89 | + module.setTargetsAttr(builder.getArrayAttr(targets)); |
| 90 | + } |
88 | 91 | } |
0 commit comments