Skip to content

Commit b937eae

Browse files
authored
Implement GL_EXT_long_vector (KhronosGroup#4132)
* Implement GL_EXT_long_vector
1 parent 1399733 commit b937eae

File tree

67 files changed

+13714
-4921
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+13714
-4921
lines changed

SPIRV/GLSL.ext.EXT.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ static const char* const E_SPV_EXT_mesh_shader = "SPV_EXT_mesh_shader";
4444
static const char* const E_SPV_EXT_float8 = "SPV_EXT_float8";
4545
static const char* const E_SPV_EXT_shader_64bit_indexing = "SPV_EXT_shader_64bit_indexing";
4646
static const char* const E_SPV_EXT_shader_invocation_reorder = "SPV_EXT_shader_invocation_reorder";
47+
static const char* const E_SPV_EXT_long_vector = "SPV_EXT_long_vector";
4748

4849
#endif // #ifndef GLSLextEXT_H

SPIRV/GlslangToSpv.cpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2698,9 +2698,13 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
26982698
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
26992699
length = builder.createCooperativeMatrixLengthNV(typeId);
27002700
}
2701-
} else if (node->getOperand()->getType().isCoopVecNV()) {
2701+
} else if (node->getOperand()->getType().isCoopVecOrLongVector()) {
27022702
spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
2703-
length = builder.getCooperativeVectorNumComponents(typeId);
2703+
if (builder.isCooperativeVectorType(typeId)) {
2704+
length = builder.getCooperativeVectorNumComponents(typeId);
2705+
} else {
2706+
length = builder.makeIntConstant(builder.getNumTypeConstituents(typeId));
2707+
}
27042708
} else {
27052709
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
27062710
block->traverse(this);
@@ -3351,7 +3355,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
33513355
builder.addCapability(spv::Capability::CooperativeMatrixConversionsNV);
33523356
builder.addExtension(spv::E_SPV_NV_cooperative_matrix2);
33533357
constructed = builder.createCooperativeMatrixConversion(resultType(), arguments[0]);
3354-
} else if (node->getOp() == glslang::EOpConstructCooperativeVectorNV &&
3358+
} else if (node->getType().isCoopVecOrLongVector() &&
33553359
arguments.size() == 1 &&
33563360
builder.getTypeId(arguments[0]) == resultType()) {
33573361
constructed = arguments[0];
@@ -3361,7 +3365,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
33613365
node->getType().isArray() ||
33623366
// Handle constructing coopvec from one component here, to avoid the component
33633367
// getting smeared
3364-
(node->getOp() == glslang::EOpConstructCooperativeVectorNV && arguments.size() == 1 && builder.isScalar(arguments[0]))) {
3368+
(node->getType().hasSpecConstantVectorComponents() && arguments.size() == 1 && builder.isScalar(arguments[0]))) {
33653369
std::vector<spv::Id> constituents;
33663370
for (int c = 0; c < (int)arguments.size(); ++c)
33673371
constituents.push_back(arguments[c]);
@@ -3423,7 +3427,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
34233427
{
34243428
// for scalar dot product, use multiply
34253429
glslang::TIntermSequence& glslangOperands = node->getSequence();
3426-
if (glslangOperands[0]->getAsTyped()->getVectorSize() == 1)
3430+
if (!glslangOperands[0]->getAsTyped()->getType().isLongVector() &&
3431+
glslangOperands[0]->getAsTyped()->getVectorSize() == 1)
34273432
binOp = glslang::EOpMul;
34283433
break;
34293434
}
@@ -5647,6 +5652,24 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
56475652
spvType = builder.makeCooperativeVectorTypeNV(spvType, components);
56485653
}
56495654

5655+
if (type.isLongVector()) {
5656+
builder.addCapability(spv::Capability::LongVectorEXT);
5657+
builder.addExtension(spv::E_SPV_EXT_long_vector);
5658+
5659+
if (type.getBasicType() == glslang::EbtFloat16)
5660+
builder.addCapability(spv::Capability::Float16);
5661+
if (type.getBasicType() == glslang::EbtUint8 || type.getBasicType() == glslang::EbtInt8) {
5662+
builder.addCapability(spv::Capability::Int8);
5663+
}
5664+
5665+
if (type.hasSpecConstantVectorComponents()) {
5666+
spv::Id components = makeArraySizeId(*type.getTypeParameters()->arraySizes, 0);
5667+
spvType = builder.makeCooperativeVectorTypeNV(spvType, components);
5668+
} else {
5669+
spvType = builder.makeVectorType(spvType, type.getTypeParameters()->arraySizes->getDimSize(0));
5670+
}
5671+
}
5672+
56505673
if (type.isArray()) {
56515674
int stride = 0; // keep this 0 unless doing an explicit layout; 0 will mean no decoration, no stride
56525675

@@ -9843,9 +9866,12 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
98439866
builder.addExtension(spv::E_SPV_AMD_gpu_shader_int16);
98449867
if (builder.getNumComponents(operands[0]) == 1)
98459868
frexpIntType = builder.makeIntegerType(width, true);
9869+
else if (builder.isCooperativeVector(operands[0]))
9870+
frexpIntType = builder.makeCooperativeVectorTypeNV(builder.makeIntegerType(width, true),
9871+
builder.getCooperativeVectorNumComponents(builder.getTypeId(operands[0])));
98469872
else
98479873
frexpIntType = builder.makeVectorType(builder.makeIntegerType(width, true),
9848-
builder.getNumComponents(operands[0]));
9874+
builder.getNumComponents(operands[0]));
98499875
typeId = builder.makeStructResultType(typeId0, frexpIntType);
98509876
consumedOperands = 1;
98519877
}
@@ -11149,8 +11175,8 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstUnionArray(const glsla
1114911175
glslang::TVector<glslang::TTypeLoc>::const_iterator iter;
1115011176
for (iter = glslangType.getStruct()->begin(); iter != glslangType.getStruct()->end(); ++iter)
1115111177
spvConsts.push_back(createSpvConstantFromConstUnionArray(*iter->type, consts, nextConst, false));
11152-
} else if (glslangType.getVectorSize() > 1 || glslangType.isCoopVecNV()) {
11153-
unsigned int numComponents = glslangType.isCoopVecNV() ? glslangType.getTypeParameters()->arraySizes->getDimSize(0) : glslangType.getVectorSize();
11178+
} else if (glslangType.getVectorSize() > 1 || glslangType.isCoopVecOrLongVector()) {
11179+
unsigned int numComponents = glslangType.isCoopVecOrLongVector() ? glslangType.getTypeParameters()->arraySizes->getDimSize(0) : glslangType.getVectorSize();
1115411180
for (unsigned int i = 0; i < numComponents; ++i) {
1115511181
bool zero = nextConst >= consts.size();
1115611182
switch (glslangType.getBasicType()) {

SPIRV/SpvBuilder.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3344,12 +3344,21 @@ Id Builder::createLvalueSwizzle(Id typeId, Id target, Id source, const std::vect
33443344
// Comments in header
33453345
void Builder::promoteScalar(Decoration precision, Id& left, Id& right)
33463346
{
3347-
int direction = getNumComponents(right) - getNumComponents(left);
3347+
// choose direction of promotion (+1 for left to right, -1 for right to left)
3348+
int direction = !isScalar(right) - !isScalar(left);
3349+
3350+
auto const &makeVec = [&](Id component, Id other) {
3351+
if (isCooperativeVector(other)) {
3352+
return makeCooperativeVectorTypeNV(getTypeId(component), getCooperativeVectorNumComponents(getTypeId(other)));
3353+
} else {
3354+
return makeVectorType(getTypeId(component), getNumComponents(other));
3355+
}
3356+
};
33483357

33493358
if (direction > 0)
3350-
left = smearScalar(precision, left, makeVectorType(getTypeId(left), getNumComponents(right)));
3359+
left = smearScalar(precision, left, makeVec(left, right));
33513360
else if (direction < 0)
3352-
right = smearScalar(precision, right, makeVectorType(getTypeId(right), getNumComponents(left)));
3361+
right = smearScalar(precision, right, makeVec(right, left));
33533362

33543363
return;
33553364
}
@@ -3361,7 +3370,7 @@ Id Builder::smearScalar(Decoration precision, Id scalar, Id vectorType)
33613370
assert(getTypeId(scalar) == getScalarTypeId(vectorType));
33623371

33633372
int numComponents = getNumTypeComponents(vectorType);
3364-
if (numComponents == 1 && !isCooperativeVectorType(vectorType))
3373+
if (numComponents == 1 && !isCooperativeVectorType(vectorType) && !isVectorType(vectorType))
33653374
return scalar;
33663375

33673376
Instruction* smear = nullptr;
@@ -3773,7 +3782,7 @@ Id Builder::createCompositeConstruct(Id typeId, const std::vector<Id>& constitue
37733782
{
37743783
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 &&
37753784
getNumTypeConstituents(typeId) == constituents.size()) ||
3776-
(isCooperativeVectorType(typeId) && constituents.size() == 1));
3785+
((isCooperativeVectorType(typeId) || isVectorType(typeId)) && constituents.size() == 1));
37773786

37783787
if (generatingOpCodeForSpecConst) {
37793788
// Sometime, even in spec-constant-op mode, the constant composite to be
@@ -3862,9 +3871,16 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc
38623871
return smearScalar(precision, sources[0], resultTypeId);
38633872

38643873
// Special case: 2 vectors of equal size
3865-
if (sources.size() == 1 && isVector(sources[0]) && numTargetComponents == getNumComponents(sources[0])) {
3866-
assert(resultTypeId == getTypeId(sources[0]));
3867-
return sources[0];
3874+
if (sources.size() == 1 &&
3875+
(isVector(sources[0]) || isCooperativeVector(sources[0])) &&
3876+
numTargetComponents == getNumComponents(sources[0])) {
3877+
if (isCooperativeVector(sources[0]) != isCooperativeVectorType(resultTypeId)) {
3878+
assert(isVector(sources[0]) != isVectorType(resultTypeId));
3879+
return createUnaryOp(spv::Op::OpBitcast, resultTypeId, sources[0]);
3880+
} else {
3881+
assert(resultTypeId == getTypeId(sources[0]));
3882+
return sources[0];
3883+
}
38683884
}
38693885

38703886
// accumulate the arguments for OpCompositeConstruct
@@ -3873,7 +3889,7 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc
38733889

38743890
// lambda to store the result of visiting an argument component
38753891
const auto latchResult = [&](Id comp) {
3876-
if (numTargetComponents > 1)
3892+
if (numTargetComponents > 1 || isVectorType(resultTypeId))
38773893
constituents.push_back(comp);
38783894
else
38793895
result = comp;
@@ -4372,7 +4388,7 @@ Id Builder::accessChainLoad(Decoration precision, Decoration l_nonUniform,
43724388
if (constant) {
43734389
id = createCompositeExtract(accessChain.base, swizzleBase, indexes);
43744390
setPrecision(id, precision);
4375-
} else if (isCooperativeVector(accessChain.base)) {
4391+
} else if (isVector(accessChain.base) || isCooperativeVector(accessChain.base)) {
43764392
assert(accessChain.indexChain.size() == 1);
43774393
id = createVectorExtractDynamic(accessChain.base, resultType, accessChain.indexChain[0]);
43784394
} else {

SPIRV/SpvTools.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ void SpirvToolsValidate(const glslang::TIntermediate& intermediate, std::vector<
166166
spvValidatorOptionsSetScalarBlockLayout(options, intermediate.usingScalarBlockLayout());
167167
spvValidatorOptionsSetWorkgroupScalarBlockLayout(options, intermediate.usingScalarBlockLayout());
168168
spvValidatorOptionsSetAllowOffsetTextureOperand(options, intermediate.usingTextureOffsetNonConst());
169+
spvValidatorOptionsSetAllowVulkan32BitBitwise(options, true);
169170
spvValidateWithOptions(context, options, &binary, &diagnostic);
170171

171172
// report

SPIRV/doc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,7 @@ const char* CapabilityString(int info)
11411141
case (int)Capability::Float8CooperativeMatrixEXT: return "Float8CooperativeMatrixEXT";
11421142

11431143
case (int)Capability::Shader64BitIndexingEXT: return "CapabilityShader64BitIndexingEXT";
1144+
case (int)Capability::LongVectorEXT: return "LongVectorEXT";
11441145

11451146
default: return "Bad";
11461147
}

SPIRV/spirv.hpp11

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,7 @@ enum class Capability : unsigned {
12001200
RawAccessChainsNV = 5414,
12011201
RayTracingSpheresGeometryNV = 5418,
12021202
RayTracingLinearSweptSpheresGeometryNV = 5419,
1203+
LongVectorEXT = 5425,
12031204
Shader64BitIndexingEXT = 5426,
12041205
CooperativeMatrixReductionsNV = 5430,
12051206
CooperativeMatrixConversionsNV = 5431,
@@ -2065,7 +2066,8 @@ enum class Op : unsigned {
20652066
OpReorderThreadWithHintNV = 5280,
20662067
OpTypeHitObjectNV = 5281,
20672068
OpImageSampleFootprintNV = 5283,
2068-
OpTypeCooperativeVectorNV = 5288,
2069+
OpTypeCooperativeVectorNV = 5288,
2070+
OpTypeVectorIdEXT = 5288,
20692071
OpCooperativeVectorMatrixMulNV = 5289,
20702072
OpCooperativeVectorOuterProductAccumulateNV = 5290,
20712073
OpCooperativeVectorReduceSumAccumulateNV = 5291,
@@ -4158,6 +4160,7 @@ inline const char* CapabilityToString(Capability value) {
41584160
case Capability::RawAccessChainsNV: return "RawAccessChainsNV";
41594161
case Capability::RayTracingSpheresGeometryNV: return "RayTracingSpheresGeometryNV";
41604162
case Capability::RayTracingLinearSweptSpheresGeometryNV: return "RayTracingLinearSweptSpheresGeometryNV";
4163+
case Capability::LongVectorEXT: return "LongVectorEXT";
41614164
case Capability::Shader64BitIndexingEXT: return "Shader64BitIndexingEXT";
41624165
case Capability::CooperativeMatrixReductionsNV: return "CooperativeMatrixReductionsNV";
41634166
case Capability::CooperativeMatrixConversionsNV: return "CooperativeMatrixConversionsNV";

0 commit comments

Comments
 (0)