Skip to content

Commit d4671f6

Browse files
spirv-val: Add OpGroupAsyncCopy and OpGroupWaitEvents (#6519)
adds `OpGroupAsyncCopy` and `OpGroupWaitEvents` (and a test for every error added)
1 parent 0a7e286 commit d4671f6

File tree

2 files changed

+345
-3
lines changed

2 files changed

+345
-3
lines changed

source/val/validate_group.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,112 @@ spv_result_t ValidateGroupInt(ValidationState_t& _, const Instruction* inst) {
8787
return SPV_SUCCESS;
8888
}
8989

90+
spv_result_t ValidateGroupAsyncCopy(ValidationState_t& _,
91+
const Instruction* inst) {
92+
if (_.FindDef(inst->type_id())->opcode() != spv::Op::OpTypeEvent) {
93+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
94+
<< "The result type must be OpTypeEvent.";
95+
}
96+
97+
const uint32_t destination = _.GetOperandTypeId(inst, 3);
98+
const Instruction* destination_pointer = _.FindDef(destination);
99+
if (destination_pointer->opcode() != spv::Op::OpTypePointer) {
100+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
101+
<< "Expected Destination to be a pointer.";
102+
}
103+
const auto destination_sc =
104+
destination_pointer->GetOperandAs<spv::StorageClass>(1);
105+
if (destination_sc != spv::StorageClass::Workgroup &&
106+
destination_sc != spv::StorageClass::CrossWorkgroup) {
107+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
108+
<< "Expected Destination to be a pointer with storage class "
109+
"Workgroup or CrossWorkgroup.";
110+
}
111+
const uint32_t destination_type =
112+
destination_pointer->GetOperandAs<uint32_t>(2);
113+
if (!_.IsIntScalarOrVectorType(destination_type) &&
114+
!_.IsFloatScalarOrVectorType(destination_type)) {
115+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
116+
<< "Expected Destination to be a pointer to scalar or vector of "
117+
"floating-point type or integer type.";
118+
}
119+
120+
const uint32_t source = _.GetOperandTypeId(inst, 4);
121+
const Instruction* source_pointer = _.FindDef(source);
122+
const auto source_sc = source_pointer->GetOperandAs<spv::StorageClass>(1);
123+
const uint32_t source_type = source_pointer->GetOperandAs<uint32_t>(2);
124+
if (destination_type != source_type) {
125+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
126+
<< "Expected Destination and Source to be the same type.";
127+
}
128+
129+
if (destination_sc == spv::StorageClass::Workgroup &&
130+
source_sc != spv::StorageClass::CrossWorkgroup) {
131+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
132+
<< "If Destination storage class is Workgroup, then the Source "
133+
"storage class must be CrossWorkgroup.";
134+
} else if (destination_sc == spv::StorageClass::CrossWorkgroup &&
135+
source_sc != spv::StorageClass::Workgroup) {
136+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
137+
<< "If Destination storage class is CrossWorkgroup, then the Source "
138+
"storage class must be Workgroup.";
139+
}
140+
141+
const bool is_physical_64 =
142+
_.addressing_model() == spv::AddressingModel::Physical64;
143+
const uint32_t bit_width = is_physical_64 ? 64 : 32;
144+
145+
const uint32_t num_elements_type =
146+
_.GetTypeId(inst->GetOperandAs<uint32_t>(5));
147+
if (!_.IsIntScalarType(num_elements_type, bit_width)) {
148+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
149+
<< "NumElements must be a " << bit_width
150+
<< "-bit int scalar when Addressing Model is "
151+
<< (is_physical_64 ? "Physical64" : "Physical32");
152+
}
153+
154+
const uint32_t stride_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(6));
155+
if (!_.IsIntScalarType(stride_type, bit_width)) {
156+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
157+
<< "Stride must be a " << bit_width
158+
<< "-bit int scalar when Addressing Model is "
159+
<< (is_physical_64 ? "Physical64" : "Physical32");
160+
}
161+
162+
const uint32_t event = _.GetOperandTypeId(inst, 7);
163+
const Instruction* event_type = _.FindDef(event);
164+
if (event_type->opcode() != spv::Op::OpTypeEvent) {
165+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
166+
<< "Expected Event to be type OpTypeEvent.";
167+
}
168+
169+
return SPV_SUCCESS;
170+
}
171+
172+
spv_result_t ValidateGroupWaitEvents(ValidationState_t& _,
173+
const Instruction* inst) {
174+
const uint32_t num_events_id = _.GetOperandTypeId(inst, 1);
175+
if (!_.IsIntScalarType(num_events_id, 32)) {
176+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
177+
<< "Expected Num Events to be a 32-bit int scalar.";
178+
}
179+
180+
const uint32_t events_id = _.GetOperandTypeId(inst, 2);
181+
const Instruction* var_pointer = _.FindDef(events_id);
182+
if (var_pointer->opcode() != spv::Op::OpTypePointer) {
183+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
184+
<< "Expected Events List to be a pointer.";
185+
}
186+
const Instruction* event_list_type =
187+
_.FindDef(var_pointer->GetOperandAs<uint32_t>(2));
188+
if (event_list_type->opcode() != spv::Op::OpTypeEvent) {
189+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
190+
<< "Expected Events List to be a pointer to OpTypeEvent.";
191+
}
192+
193+
return SPV_SUCCESS;
194+
}
195+
90196
} // namespace
91197

92198
spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst) {
@@ -108,6 +214,10 @@ spv_result_t GroupPass(ValidationState_t& _, const Instruction* inst) {
108214
case spv::Op::OpGroupUMax:
109215
case spv::Op::OpGroupSMax:
110216
return ValidateGroupInt(_, inst);
217+
case spv::Op::OpGroupAsyncCopy:
218+
return ValidateGroupAsyncCopy(_, inst);
219+
case spv::Op::OpGroupWaitEvents:
220+
return ValidateGroupWaitEvents(_, inst);
111221
default:
112222
break;
113223
}

test/val/val_group_test.cpp

Lines changed: 235 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using ::testing::HasSubstr;
2525

2626
using ValidateGroup = spvtest::ValidateBase<bool>;
2727

28-
std::string GenerateShaderCode(const std::string& body) {
28+
std::string GenerateShaderCode(const std::string& body, bool is_64_bit = true) {
2929
std::ostringstream ss;
3030
ss << R"(
3131
OpCapability Kernel
@@ -34,19 +34,43 @@ OpCapability Linkage
3434
OpCapability Groups
3535
OpCapability Float64
3636
OpCapability Int64
37-
OpMemoryModel Physical64 OpenCL
37+
)";
38+
if (is_64_bit) {
39+
ss << "OpMemoryModel Physical64 OpenCL";
40+
} else {
41+
ss << "OpMemoryModel Physical32 OpenCL";
42+
}
43+
ss << R"(
3844
OpEntryPoint Kernel %main "main"
45+
%bool = OpTypeBool
3946
%float = OpTypeFloat 32
4047
%float64 = OpTypeFloat 64
4148
%uint = OpTypeInt 32 0
4249
%uint64 = OpTypeInt 64 0
50+
%null_uint = OpConstantNull %uint
4351
%uint_0 = OpConstant %uint 0
52+
%uint_1 = OpConstant %uint 1
4453
%uint_2 = OpConstant %uint 2
54+
%uint64_1 = OpConstant %uint64 1
4555
%float_2 = OpConstant %float 2
4656
%uint_array = OpTypeArray %uint %uint_2
4757
%void = OpTypeVoid
58+
%event = OpTypeEvent
59+
%null_event = OpConstantNull %event
60+
61+
%workgroup_float_ptr = OpTypePointer Workgroup %float
62+
%workgroup_float_var = OpVariable %workgroup_float_ptr Workgroup
63+
%workgroup_bool_ptr = OpTypePointer Workgroup %bool
64+
%workgroup_bool_var = OpVariable %workgroup_bool_ptr Workgroup
65+
%cross_float_ptr = OpTypePointer CrossWorkgroup %float
66+
%cross_float_var = OpVariable %cross_float_ptr CrossWorkgroup
67+
%cross_uint_ptr = OpTypePointer CrossWorkgroup %uint
68+
%cross_uint_var = OpVariable %cross_uint_ptr CrossWorkgroup
69+
%uniform_float_ptr = OpTypePointer UniformConstant %float
70+
%uniform_float_var = OpVariable %uniform_float_ptr UniformConstant
71+
%func_event_ptr = OpTypePointer Function %event
72+
4873
%fn = OpTypeFunction %void
49-
%bool = OpTypeBool
5074
%true = OpConstantTrue %bool
5175
%main = OpFunction %void None %fn
5276
%label = OpLabel
@@ -211,6 +235,214 @@ TEST_F(ValidateGroup, BroadcastMismatch) {
211235
HasSubstr("The type of Value must match the Result type"));
212236
}
213237

238+
TEST_F(ValidateGroup, AsyncCopyWaitEventsGood) {
239+
const std::string ss = R"(
240+
OpCapability Kernel
241+
OpCapability Addresses
242+
OpCapability Int64
243+
OpCapability Int8
244+
OpCapability Linkage
245+
OpMemoryModel Physical64 OpenCL
246+
OpEntryPoint Kernel %async_example "async_example"
247+
OpExecutionMode %async_example ContractionOff
248+
OpDecorate %26 Alignment 4
249+
OpDecorate %29 Alignment 8
250+
OpDecorate %async_example_local_data Alignment 4
251+
%float = OpTypeFloat 32
252+
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
253+
%void = OpTypeVoid
254+
%5 = OpTypeFunction %void %_ptr_CrossWorkgroup_float
255+
%spirv_Event = OpTypeEvent
256+
%_ptr_Workgroup_float = OpTypePointer Workgroup %float
257+
%ulong = OpTypeInt 64 0
258+
%uint = OpTypeInt 32 0
259+
%_ptr_Function_spirv_Event = OpTypePointer Function %spirv_Event
260+
%uint_64 = OpConstant %uint 64
261+
%_arr_float_uint_64 = OpTypeArray %float %uint_64
262+
%_ptr_Workgroup__arr_float_uint_64 = OpTypePointer Workgroup %_arr_float_uint_64
263+
%ulong_1 = OpConstant %ulong 1
264+
%ulong_64 = OpConstant %ulong 64
265+
%uint_1 = OpConstant %uint 1
266+
%uint_2 = OpConstant %uint 2
267+
%uchar = OpTypeInt 8 0
268+
%_ptr_Function_uchar = OpTypePointer Function %uchar
269+
%async_example_local_data = OpVariable %_ptr_Workgroup__arr_float_uint_64 Workgroup
270+
%24 = OpConstantNull %spirv_Event
271+
%async_example = OpFunction %void None %5
272+
%26 = OpFunctionParameter %_ptr_CrossWorkgroup_float
273+
%54 = OpLabel
274+
%29 = OpVariable %_ptr_Function_spirv_Event Function
275+
%30 = OpBitcast %_ptr_Workgroup_float %async_example_local_data
276+
%31 = OpBitcast %_ptr_Function_uchar %29
277+
%32 = OpGroupAsyncCopy %spirv_Event %uint_2 %30 %26 %ulong_64 %ulong_1 %24
278+
OpStore %29 %32 Aligned 8
279+
OpGroupWaitEvents %uint_2 %uint_1 %29
280+
OpReturn
281+
OpFunctionEnd
282+
)";
283+
CompileSuccessfully(ss);
284+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
285+
}
286+
287+
TEST_F(ValidateGroup, AsyncCopyResultType) {
288+
const std::string ss = R"(
289+
%a = OpGroupAsyncCopy %uint %uint_2 %workgroup_float_var %cross_float_var %uint64_1 %uint64_1 %null_event
290+
)";
291+
CompileSuccessfully(GenerateShaderCode(ss));
292+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
293+
EXPECT_THAT(getDiagnosticString(),
294+
HasSubstr("The result type must be OpTypeEvent"));
295+
}
296+
297+
TEST_F(ValidateGroup, AsyncCopyDestinationPointer) {
298+
const std::string ss = R"(
299+
%a = OpGroupAsyncCopy %event %uint_2 %null_uint %cross_float_var %uint64_1 %uint64_1 %null_event
300+
)";
301+
CompileSuccessfully(GenerateShaderCode(ss));
302+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
303+
EXPECT_THAT(getDiagnosticString(),
304+
HasSubstr("Expected Destination to be a pointer"));
305+
}
306+
307+
TEST_F(ValidateGroup, AsyncCopyDestinationUniform) {
308+
const std::string ss = R"(
309+
%a = OpGroupAsyncCopy %event %uint_2 %uniform_float_var %cross_float_var %uint64_1 %uint64_1 %null_event
310+
)";
311+
CompileSuccessfully(GenerateShaderCode(ss));
312+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
313+
EXPECT_THAT(getDiagnosticString(),
314+
HasSubstr("Expected Destination to be a pointer with storage "
315+
"class Workgroup or CrossWorkgroup"));
316+
}
317+
318+
TEST_F(ValidateGroup, AsyncCopyDestinationBool) {
319+
const std::string ss = R"(
320+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_bool_var %cross_float_var %uint64_1 %uint64_1 %null_event
321+
)";
322+
CompileSuccessfully(GenerateShaderCode(ss));
323+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
324+
EXPECT_THAT(getDiagnosticString(),
325+
HasSubstr("Expected Destination to be a pointer to scalar or "
326+
"vector of floating-point type or integer type"));
327+
}
328+
329+
TEST_F(ValidateGroup, AsyncCopyDestinationSourceTypes) {
330+
const std::string ss = R"(
331+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_uint_var %uint64_1 %uint64_1 %null_event
332+
)";
333+
CompileSuccessfully(GenerateShaderCode(ss));
334+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
335+
EXPECT_THAT(getDiagnosticString(),
336+
HasSubstr("Expected Destination and Source to be the same type"));
337+
}
338+
339+
TEST_F(ValidateGroup, AsyncCopyBothWorkgroup) {
340+
const std::string ss = R"(
341+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %workgroup_float_var %uint64_1 %uint64_1 %null_event
342+
)";
343+
CompileSuccessfully(GenerateShaderCode(ss));
344+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
345+
EXPECT_THAT(getDiagnosticString(),
346+
HasSubstr("If Destination storage class is Workgroup, then the "
347+
"Source storage class must be CrossWorkgroup."));
348+
}
349+
350+
TEST_F(ValidateGroup, AsyncCopyBothCrossWorkgroup) {
351+
const std::string ss = R"(
352+
%a = OpGroupAsyncCopy %event %uint_2 %cross_float_var %cross_float_var %uint64_1 %uint64_1 %null_event
353+
)";
354+
CompileSuccessfully(GenerateShaderCode(ss));
355+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
356+
EXPECT_THAT(getDiagnosticString(),
357+
HasSubstr("If Destination storage class is CrossWorkgroup, then "
358+
"the Source storage class must be Workgroup"));
359+
}
360+
361+
TEST_F(ValidateGroup, AsyncCopyEventType) {
362+
const std::string ss = R"(
363+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_float_var %uint64_1 %uint64_1 %null_uint
364+
)";
365+
CompileSuccessfully(GenerateShaderCode(ss));
366+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
367+
EXPECT_THAT(getDiagnosticString(),
368+
HasSubstr("Expected Event to be type OpTypeEvent"));
369+
}
370+
371+
TEST_F(ValidateGroup, AsyncCopyNumElement32Bit) {
372+
const std::string ss = R"(
373+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_float_var %uint_1 %uint64_1 %null_event
374+
)";
375+
CompileSuccessfully(GenerateShaderCode(ss));
376+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
377+
EXPECT_THAT(getDiagnosticString(),
378+
HasSubstr("NumElements must be a 64-bit int scalar when "
379+
"Addressing Model is Physical64"));
380+
}
381+
382+
TEST_F(ValidateGroup, AsyncCopyStride32Bit) {
383+
const std::string ss = R"(
384+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_float_var %uint64_1 %uint_1 %null_event
385+
)";
386+
CompileSuccessfully(GenerateShaderCode(ss));
387+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
388+
EXPECT_THAT(getDiagnosticString(),
389+
HasSubstr("Stride must be a 64-bit int scalar when Addressing "
390+
"Model is Physical64"));
391+
}
392+
393+
TEST_F(ValidateGroup, AsyncCopyNumElement64Bit) {
394+
const std::string ss = R"(
395+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_float_var %uint64_1 %uint_1 %null_event
396+
)";
397+
CompileSuccessfully(GenerateShaderCode(ss, false));
398+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
399+
EXPECT_THAT(getDiagnosticString(),
400+
HasSubstr("NumElements must be a 32-bit int scalar when "
401+
"Addressing Model is Physical32"));
402+
}
403+
404+
TEST_F(ValidateGroup, AsyncCopyStride64Bit) {
405+
const std::string ss = R"(
406+
%a = OpGroupAsyncCopy %event %uint_2 %workgroup_float_var %cross_float_var %uint_1 %uint64_1 %null_event
407+
)";
408+
CompileSuccessfully(GenerateShaderCode(ss, false));
409+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
410+
EXPECT_THAT(getDiagnosticString(),
411+
HasSubstr("Stride must be a 32-bit int scalar when Addressing "
412+
"Model is Physical32"));
413+
}
414+
415+
TEST_F(ValidateGroup, GroupWaitEventsNumEvents) {
416+
const std::string ss = R"(
417+
%a = OpVariable %func_event_ptr Function
418+
OpGroupWaitEvents %uint_2 %uint64_1 %a
419+
)";
420+
CompileSuccessfully(GenerateShaderCode(ss));
421+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
422+
EXPECT_THAT(getDiagnosticString(),
423+
HasSubstr("Expected Num Events to be a 32-bit int scalar"));
424+
}
425+
426+
TEST_F(ValidateGroup, GroupWaitEventsEventList) {
427+
const std::string ss = R"(
428+
OpGroupWaitEvents %uint_2 %uint_1 %null_uint
429+
)";
430+
CompileSuccessfully(GenerateShaderCode(ss));
431+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
432+
EXPECT_THAT(getDiagnosticString(),
433+
HasSubstr("Expected Events List to be a pointer"));
434+
}
435+
436+
TEST_F(ValidateGroup, GroupWaitEventsEventListType) {
437+
const std::string ss = R"(
438+
OpGroupWaitEvents %uint_2 %uint_1 %uniform_float_var
439+
)";
440+
CompileSuccessfully(GenerateShaderCode(ss));
441+
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
442+
EXPECT_THAT(getDiagnosticString(),
443+
HasSubstr("Expected Events List to be a pointer to OpTypeEvent"));
444+
}
445+
214446
} // namespace
215447
} // namespace val
216448
} // namespace spvtools

0 commit comments

Comments
 (0)