Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,11 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
for (Attribute targetAttr : targets)
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
if (auto spirvTargetEnvAttr =
dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
break;
}
}

rewriter.eraseOp(moduleOp);
Expand Down
24 changes: 23 additions & 1 deletion mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt --convert-gpu-to-spirv %s | FileCheck %s
// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s

module attributes {gpu.container_module} {
// CHECK-LABEL: spirv.module @{{.*}} GLSL450
Expand All @@ -15,3 +15,25 @@ module attributes {gpu.container_module} {
}
}
}

// -----

module attributes {gpu.container_module} {
// CHECK-LABEL: spirv.module @{{.*}} GLSL450
// CHECK-SAME: #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>
gpu.module @kernels [
#spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>,
#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>,
#spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>] {
// CHECK: spirv.func @load_kernel
// CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%c0 = arith.constant 0 : index
// CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
%0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
// CHECK: spirv.Return
gpu.return
}
}
}