Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,11 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
for (const Attribute &targetAttr : moduleOp.getTargetsAttr())
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
if (const ArrayAttr &targets = moduleOp.getTargetsAttr()) {
for (const Attribute &targetAttr : targets)
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
}

rewriter.eraseOp(moduleOp);
return success();
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {

spirv::TargetEnvAttr
GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
for (const Attribute &targetAttr : moduleOp.getTargetsAttr())
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
return spirvTargetEnvAttr;
if (const ArrayAttr &targets = moduleOp.getTargetsAttr()) {
for (const Attribute &targetAttr : targets)
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
return spirvTargetEnvAttr;
}

return {};
}
Expand Down