Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions SPIRV/GlslangToSpv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3159,6 +3159,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
std::vector<spv::Builder::AccessChain> complexLvalues; // for holding swizzling l-values too complex for
// SPIR-V, for an out parameter
std::vector<spv::Id> temporaryLvalues; // temporaries to pass, as proxies for complexLValues
spv::Builder::AccessChain tensorReadResultLValue = {};
tensorReadResultLValue.base = spv::NoResult; // deferred tensorReadARM out-arg store target

auto resultType = [&invertedType, &node, this](){
if (invertedType != spv::NoType) {
Expand Down Expand Up @@ -4170,7 +4172,17 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
builder.accessChainGetInferredType(), "swizzleTemp"));
operands.push_back(temporaryLvalues.back());
} else {
operands.push_back(builder.accessChainGetLValue());
if (node->getOp() == glslang::EOpTensorReadARM && arg == 2) {
// tensorReadARM stores the result after emitting the op, so keep the
// original l-value access chain and avoid materializing a transient
// pointer that may not preserve descriptor-heap indexing.
tensorReadResultLValue = builder.getAccessChain();
// Keep the operand slot so optional tensor operands keep their
// existing indices in the later lowering logic.
operands.push_back(spv::NoResult);
} else {
operands.push_back(builder.accessChainGetLValue());
}
}
lvalueCoherentFlags = builder.getAccessChain().coherentFlags;
lvalueCoherentFlags |= TranslateCoherent(glslangOperands[arg]->getAsTyped()->getType());
Expand Down Expand Up @@ -4669,7 +4681,9 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
spv::Id retType = convertGlslangToSpvType(resArgType);
result = builder.createOp(spv::Op::OpTensorReadARM, retType, idImmOps);
// Store the result to the result argument.
builder.createStore(result, operands[2]);
assert(tensorReadResultLValue.base != spv::NoResult);
builder.setAccessChain(tensorReadResultLValue);
accessChainStore(resArgType, result);
}
} else if (node->getOp() == glslang::EOpTensorSizeARM) {
// Expected operands are (tensor, dimension)
Expand Down
66 changes: 66 additions & 0 deletions Test/baseResults/spv.tensorARM.descriptorHeap.comp.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
spv.tensorARM.descriptorHeap.comp
// Module Version 10000
// Generated by (magic number): 8000b
// Id's are bound by 32

Capability Shader
Capability TensorsARM
Capability UntypedPointersKHR
Capability DescriptorHeapEXT
Extension "SPV_ARM_tensors"
Extension "SPV_EXT_descriptor_heap"
Extension "SPV_KHR_storage_buffer_storage_class"
Extension "SPV_KHR_untyped_pointers"
1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450
EntryPoint GLCompute 4 "main"
ExecutionMode 4 LocalSize 1 1 1
Source GLSL 460
SourceExtension "GL_ARM_tensors"
SourceExtension "GL_EXT_descriptor_heap"
SourceExtension "GL_EXT_shader_explicit_arithmetic_types"
Name 4 "main"
Name 11 "t"
Name 19 "resource_heap"
Name 21 "O"
MemberName 21(O) 0 "out_data"
Decorate 11(t) Binding 0
Decorate 11(t) DescriptorSet 0
Decorate 19(resource_heap) BuiltIn ResourceHeapEXT
Decorate 21(O) Block
MemberDecorate 21(O) 0 Offset 0
DecorateId 29 DecorationArrayStrideIdEXT 28
2: TypeVoid
3: TypeFunction 2
6: TypeInt 32 1
7: TypeInt 32 0
8: 7(int) Constant 4
9: TypeTensorARM 6(int) 8
10: TypePointer UniformConstant 9
11(t): 10(ptr) Variable UniformConstant
13: TypeArray 7(int) 8
14: 7(int) Constant 1
15: 7(int) Constant 2
16: 7(int) Constant 3
17: 13 ConstantComposite 14 15 16 8
18: TypeUntypedPointerKHR UniformConstant
19(resource_heap): 18(ptr) UntypedVariableKHR UniformConstant
20: 6(int) Constant 1
21(O): TypeStruct 6(int)
22: 6(int) Constant 0
23: 7(int) Constant 0
25: TypeUntypedPointerKHR StorageBuffer
27: TypeBufferEXT StorageBuffer
28: 6(int) ConstantSizeOfEXT 27
29: TypeRuntimeArray 27
4(main): 2 Function None 3
5: Label
12: 9 Load 11(t)
24: 6(int) TensorReadARM 12 17
26: 18(ptr) UntypedAccessChainKHR 29 19(resource_heap) 20
30: 25(ptr) BufferPointerEXT 26
31: 25(ptr) UntypedAccessChainKHR 21(O) 30 22
Store 31 24
Return
FunctionEnd

8 changes: 4 additions & 4 deletions Test/baseResults/spv.tensorARM.read.comp.out
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ spv.tensorARM.read.comp
32: 31(ptr) Variable StorageBuffer
33: 6(int) Constant 0
34: 6(int) Constant 1
35: TypePointer StorageBuffer 6(int)
36: TypePointer StorageBuffer 6(int)
4(main): 2 Function None 3
5: Label
19(one): 18(ptr) Variable Function
Expand All @@ -64,8 +64,8 @@ spv.tensorARM.read.comp
27: 23 TensorReadARM 22 17
Store 25(two) 27
28: 9 Load 11(t)
36: 35(ptr) AccessChain 32 33 34
37: 6(int) TensorReadARM 28 17
Store 36 37
35: 6(int) TensorReadARM 28 17
37: 36(ptr) AccessChain 32 33 34
Store 37 35
Return
FunctionEnd
16 changes: 16 additions & 0 deletions Test/spv.tensorARM.descriptorHeap.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#version 460 core
#extension GL_ARM_tensors : enable
#extension GL_EXT_descriptor_heap : require
#extension GL_EXT_shader_explicit_arithmetic_types : enable

// Source tensor for the read: int32_t elements, with 4 as the second type
// parameter (presumably the tensor rank per GL_ARM_tensors -- confirm).
// Bound through a conventional descriptor set/binding pair (set 0, binding 0).
layout(set = 0, binding = 0) uniform tensorARM<int32_t, 4> t;

// Destination buffer block. It is qualified with descriptor_heap instead of
// set/binding and declared as an unsized array (ssbo[]), so each use selects
// an element by index.
layout(descriptor_heap) buffer O {
// Single int32_t member; main() passes it as the output argument of tensorReadARM.
int32_t out_data;
} ssbo[];

// Entry point: issues one tensorReadARM on `t` using the constant uint[4]
// operand (1, 2, 3, 4) (presumably the element coordinates -- confirm against
// GL_ARM_tensors) and passes ssbo[1].out_data as the third (output) argument.
// The output l-value is a member of an indexed descriptor_heap buffer block,
// which is the case this test is meant to cover.
void main()
{
tensorReadARM(t, uint[](1, 2, 3, 4), ssbo[1].out_data);
}

Loading