Skip to content

Commit 08f1e75

Browse files
authored
spir-val: fix OpTensor{Read,Write}ARM for tensors with a spec constant rank (KhronosGroup#6206)
Change-Id: Iad669fb55832dc6ad75e0a59876ebd3a2501d9e7 Signed-off-by: Kevin Petit <[email protected]>
1 parent 54fc952 commit 08f1e75

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

source/val/validate_tensor.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ spv_result_t ValidateTensorRead(ValidationState_t& _, const Instruction* inst) {
8383
auto op_coord = inst->word(4);
8484
auto inst_coord = _.FindDef(op_coord);
8585
auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
86-
if (tensor_rank == 0 ||
87-
!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
86+
if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
8887
return _.diag(SPV_ERROR_INVALID_DATA, inst)
8988
<< "Expected Coordinates to be an array whose Element Type is an "
9089
"integer type and whose Length is equal to the Rank of Tensor.";
@@ -143,8 +142,7 @@ spv_result_t ValidateTensorWrite(ValidationState_t& _,
143142
auto op_coord = inst->word(2);
144143
auto inst_coord = _.FindDef(op_coord);
145144
auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
146-
if (tensor_rank == 0 ||
147-
!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
145+
if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
148146
return _.diag(SPV_ERROR_INVALID_DATA, inst)
149147
<< "Expected Coordinates to be an array whose Element Type is an "
150148
"integer type and whose Length is equal to the Rank of Tensor.";

test/val/val_tensor_test.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ std::string GenerateModule(const std::string& body) {
7171
%uint_var_1 = OpVariable %uint_ptr_Private Private %uint_1
7272
%var_uint_arr4_1_1_1_1 = OpVariable %uint_arr4_ptr_Private Private %uint_arr4_1_1_1_1
7373
%tensor_uint_4 = OpTypeTensorARM %uint %uint_4
74+
%tensor_uint_spec = OpTypeTensorARM %uint %uint_0_spec
7475
%tensor_float = OpTypeTensorARM %float
7576
%tensor_uint_4_ptr_UniformConstant = OpTypePointer UniformConstant %tensor_uint_4
7677
%tensor_var = OpVariable %tensor_uint_4_ptr_UniformConstant UniformConstant
7778
%tensor_float_ptr_UniformConstant = OpTypePointer UniformConstant %tensor_float
7879
%tensor_var_float_unranked = OpVariable %tensor_float_ptr_UniformConstant UniformConstant
80+
%tensor_uint_spec_ptr_UniformConstant = OpTypePointer UniformConstant %tensor_uint_spec
81+
%tensor_var_spec_rank = OpVariable %tensor_uint_spec_ptr_UniformConstant UniformConstant
7982
)";
8083
const std::string footer = R"(
8184
%fnep = OpFunction %void None %fnty
@@ -135,7 +138,8 @@ TEST_F(ValidateTensor, ValidTypeElementTypeAndRank) {
135138

136139
TEST_F(ValidateTensor, ValidTypeElementTypeAndRankUsingSpecConstant) {
137140
const std::string src = R"(
138-
%test_type = OpTypeTensorARM %uint %uint_0_spec
141+
%rank_spec = OpSpecConstant %uint 0
142+
%test_type = OpTypeTensorARM %uint %rank_spec
139143
)";
140144
std::string spvasm = GenerateModule(src);
141145
CompileSuccessfully(spvasm, SPVENV);
@@ -563,6 +567,20 @@ TEST_F(ValidateTensor, ValidTensorReadArray) {
563567
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPVENV));
564568
}
565569

570+
TEST_F(ValidateTensor, ValidTensorReadSpecConstantRank) {
571+
const std::string src = R"(
572+
%fn = OpFunction %void None %fnty
573+
%label1 = OpLabel
574+
%tensor = OpLoad %tensor_uint_spec %tensor_var_spec_rank
575+
%val = OpTensorReadARM %uint %tensor %uint_arr4_1_1_1_1
576+
OpReturn
577+
OpFunctionEnd
578+
)";
579+
std::string spvasm = GenerateModule(src);
580+
CompileSuccessfully(spvasm, SPVENV);
581+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPVENV));
582+
}
583+
566584
TEST_F(ValidateTensor, InvalidTensorReadResultTypeVoid) {
567585
const std::string src = R"(
568586
%fn = OpFunction %void None %fnty
@@ -840,6 +858,20 @@ TEST_F(ValidateTensor, ValidTensorWriteArray) {
840858
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPVENV));
841859
}
842860

861+
TEST_F(ValidateTensor, ValidTensorWriteSpecConstantRank) {
862+
const std::string src = R"(
863+
%fn = OpFunction %void None %fnty
864+
%label1 = OpLabel
865+
%tensor = OpLoad %tensor_uint_spec %tensor_var_spec_rank
866+
OpTensorWriteARM %tensor %uint_arr4_1_1_1_1 %uint_1
867+
OpReturn
868+
OpFunctionEnd
869+
)";
870+
std::string spvasm = GenerateModule(src);
871+
CompileSuccessfully(spvasm, SPVENV);
872+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPVENV));
873+
}
874+
843875
TEST_F(ValidateTensor, InvalidTensorWriteObjectNotScalarOrArrayOfScalar) {
844876
const std::string src = R"(
845877
%fn = OpFunction %void None %fnty
@@ -1069,12 +1101,9 @@ TEST_F(ValidateTensor, ValidTensorQuerySize) {
10691101

10701102
TEST_F(ValidateTensor, ValidTensorQuerySizeSpecConstant) {
10711103
const std::string src = R"(
1072-
%tensor_uint_4_spec = OpTypeTensorARM %uint %uint_0_spec
1073-
%tensor_uint_4_spec_ptr_UniformConstant = OpTypePointer UniformConstant %tensor_uint_4_spec
1074-
%tensor_var_spec = OpVariable %tensor_uint_4_spec_ptr_UniformConstant UniformConstant
10751104
%fn = OpFunction %void None %fnty
10761105
%label1 = OpLabel
1077-
%tensor = OpLoad %tensor_uint_4_spec %tensor_var_spec
1106+
%tensor = OpLoad %tensor_uint_spec %tensor_var_spec_rank
10781107
%size = OpTensorQuerySizeARM %uint %tensor %uint_0_spec
10791108
OpReturn
10801109
OpFunctionEnd

0 commit comments

Comments
 (0)