Skip to content

Commit 3ac12f1

Browse files
spirv-val: Add SPV_KHR_integer_dot_product (#6524)
We seem to just have actually never added validation for `SPV_KHR_integer_dot_product` in #4327 (and seems after a few years, no one was going to do it) This adds it
1 parent 1c69c17 commit 3ac12f1

File tree

7 files changed

+411
-49
lines changed

7 files changed

+411
-49
lines changed

Android.mk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ SPVTOOLS_SRC_FILES := \
5353
source/val/validate_debug.cpp \
5454
source/val/validate_decorations.cpp \
5555
source/val/validate_derivatives.cpp \
56+
source/val/validate_dot_product.cpp \
5657
source/val/validate_extensions.cpp \
5758
source/val/validate_execution_limitations.cpp \
5859
source/val/validate_function.cpp \

BUILD.gn

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ static_library("spvtools_val") {
346346
"source/val/validate_debug.cpp",
347347
"source/val/validate_decorations.cpp",
348348
"source/val/validate_derivatives.cpp",
349+
"source/val/validate_dot_product.cpp",
349350
"source/val/validate_execution_limitations.cpp",
350351
"source/val/validate_extensions.cpp",
351352
"source/val/validate_function.cpp",

source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ set(SPIRV_SOURCES
260260
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp
261261
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp
262262
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp
263+
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_dot_product.cpp
263264
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_extensions.cpp
264265
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_execution_limitations.cpp
265266
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_function.cpp

source/val/validate.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
390390
if (auto error = AtomicsPass(*vstate, &instruction)) return error;
391391
if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
392392
if (auto error = BarriersPass(*vstate, &instruction)) return error;
393+
if (auto error = DotProductPass(*vstate, &instruction)) return error;
393394
if (auto error = GroupPass(*vstate, &instruction)) return error;
394395
// Device-Side Enqueue
395396
// Pipe

source/val/validate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst);
180180
/// Validates correctness of barrier instructions.
181181
spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst);
182182

183+
/// Validates correctness of DotProduct instructions.
184+
spv_result_t DotProductPass(ValidationState_t& _, const Instruction* inst);
185+
183186
/// Validates correctness of Group (Kernel) instructions.
184187
spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst);
185188

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
// Copyright (c) 2026 LunarG Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <cstdint>
16+
17+
#include "source/val/instruction.h"
18+
#include "source/val/validate.h"
19+
#include "source/val/validate_scopes.h"
20+
#include "source/val/validation_state.h"
21+
22+
namespace spvtools {
23+
namespace val {
24+
namespace {
25+
26+
spv_result_t ValidateSameSignedDot(ValidationState_t& _,
27+
const Instruction* inst) {
28+
const uint32_t result_id = inst->type_id();
29+
if (!_.IsIntScalarType(result_id)) {
30+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
31+
<< "Result must be an int scalar type.";
32+
}
33+
34+
const spv::Op opcode = inst->opcode();
35+
const bool has_accumulator = opcode == spv::Op::OpSDotAccSat ||
36+
opcode == spv::Op::OpUDotAccSat ||
37+
opcode == spv::Op::OpSUDotAccSat;
38+
if (has_accumulator) {
39+
const uint32_t accumulator_type = _.GetOperandTypeId(inst, 4);
40+
if (accumulator_type != result_id) {
41+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
42+
<< "Result must be the same as the Accumulator type.";
43+
}
44+
}
45+
46+
if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) {
47+
if (!_.IsIntScalarTypeWithSignedness(result_id, 0)) {
48+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
49+
<< "Result must be an unsigned int scalar type.";
50+
}
51+
}
52+
53+
const uint32_t vec_1_id = _.GetOperandTypeId(inst, 2);
54+
const uint32_t vec_2_id = _.GetOperandTypeId(inst, 3);
55+
56+
const bool is_vec_1_scalar = _.IsIntScalarType(vec_1_id, 32);
57+
const bool is_vec_2_scalar = _.IsIntScalarType(vec_2_id, 32);
58+
if (is_vec_1_scalar != is_vec_2_scalar) {
59+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
60+
<< "'Vector 1' and 'Vector 2' must be the same type.";
61+
} else if (is_vec_1_scalar && is_vec_2_scalar) {
62+
// If both are scalar, spec doesn't say Signedness needs to match
63+
const uint32_t vec_1_width = _.GetBitWidth(vec_1_id);
64+
const uint32_t vec_2_width = _.GetBitWidth(vec_2_id);
65+
if (vec_1_width != 32) {
66+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
67+
<< "Expected 'Vector 1' to be 32-bit when a scalar.";
68+
} else if (vec_2_width != 32) {
69+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
70+
<< "Expected 'Vector 2' to be 32-bit when a scalar.";
71+
}
72+
73+
// When packed, the result can be 8-bit
74+
const uint32_t result_width = _.GetBitWidth(result_id);
75+
if (result_width < 8) {
76+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
77+
<< "Result width (" << result_width
78+
<< ") must be greater than or equal to the packed vector width of "
79+
"8";
80+
}
81+
82+
// PackedVectorFormat4x8Bit is used when the "Vector" operand are really
83+
// scalar
84+
const uint32_t packed_operand = has_accumulator ? 6 : 5;
85+
const bool has_packed_vec_format =
86+
inst->operands().size() == packed_operand;
87+
if (!has_packed_vec_format) {
88+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
89+
<< "'Vector 1' and 'Vector 2' are a 32-bit int scalar, but no "
90+
"Packed Vector "
91+
"Format was provided.";
92+
}
93+
} else {
94+
// both should be vectors
95+
96+
if (!_.IsVectorType(vec_1_id)) {
97+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
98+
<< "Expected 'Vector 1' to be an int scalar or vector.";
99+
} else if (!_.IsVectorType(vec_2_id)) {
100+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
101+
<< "Expected 'Vector 2' to be an int scalar or vector.";
102+
}
103+
104+
const uint32_t vec_1_length = _.GetDimension(vec_1_id);
105+
const uint32_t vec_2_length = _.GetDimension(vec_2_id);
106+
// If using OpTypeVectorIdEXT with a spec constant, this can be evaluated
107+
// when spec constants are frozen
108+
if (vec_1_length != 0 && vec_2_length != 0 &&
109+
vec_1_length != vec_2_length) {
110+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
111+
<< "'Vector 1' is " << vec_1_length
112+
<< " components but 'Vector 2' is " << vec_2_length
113+
<< " components";
114+
}
115+
116+
const uint32_t vec_1_type = _.GetComponentType(vec_1_id);
117+
const uint32_t vec_2_type = _.GetComponentType(vec_2_id);
118+
if (!_.IsIntScalarType(vec_1_type)) {
119+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
120+
<< "Expected 'Vector 1' to be a vector of integers.";
121+
} else if (!_.IsIntScalarType(vec_2_type)) {
122+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
123+
<< "Expected 'Vector 2' to be a vector of integers.";
124+
}
125+
126+
const uint32_t vec_1_width = _.GetBitWidth(vec_1_type);
127+
const uint32_t vec_2_width = _.GetBitWidth(vec_2_type);
128+
if (vec_1_width != vec_2_width) {
129+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
130+
<< "'Vector 1' component is " << vec_1_width
131+
<< "-bit but 'Vector 2' component is " << vec_2_width << "-bit";
132+
}
133+
134+
const uint32_t result_width = _.GetBitWidth(result_id);
135+
if (result_width < vec_1_width) {
136+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
137+
<< "Result width (" << result_width
138+
<< ") must be greater than or equal to the vectors width ("
139+
<< vec_1_width << ").";
140+
}
141+
142+
if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) {
143+
const bool vec_1_unsigned =
144+
_.IsIntScalarTypeWithSignedness(vec_1_type, 0);
145+
const bool vec_2_unsigned =
146+
_.IsIntScalarTypeWithSignedness(vec_2_type, 0);
147+
if (!vec_1_unsigned) {
148+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
149+
<< "Expected 'Vector 1' to be an vector of unsigned integers.";
150+
} else if (!vec_2_unsigned) {
151+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
152+
<< "Expected 'Vector 2' to be an vector of unsigned integers.";
153+
}
154+
} else if (opcode == spv::Op::OpSUDot || opcode == spv::Op::OpSUDotAccSat) {
155+
const bool vec_2_unsigned =
156+
_.IsIntScalarTypeWithSignedness(vec_2_type, 0);
157+
if (!vec_2_unsigned) {
158+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
159+
<< "Expected 'Vector 2' to be an vector of unsigned integers.";
160+
}
161+
}
162+
}
163+
164+
return SPV_SUCCESS;
165+
}
166+
167+
} // namespace
168+
169+
spv_result_t DotProductPass(ValidationState_t& _, const Instruction* inst) {
170+
const spv::Op opcode = inst->opcode();
171+
172+
switch (opcode) {
173+
case spv::Op::OpSDot:
174+
case spv::Op::OpUDot:
175+
case spv::Op::OpSUDot:
176+
case spv::Op::OpSDotAccSat:
177+
case spv::Op::OpUDotAccSat:
178+
case spv::Op::OpSUDotAccSat:
179+
return ValidateSameSignedDot(_, inst);
180+
default:
181+
break;
182+
}
183+
184+
return SPV_SUCCESS;
185+
}
186+
187+
} // namespace val
188+
} // namespace spvtools

0 commit comments

Comments
 (0)