@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
3636 spirv.ReturnValue %result : vector <4 xf32 >
3737 }
3838
39+ // CHECK-LABEL: @vector_times_matrix_1
40+ spirv.func @vector_times_matrix_1 (%arg0: vector <3 xf32 >, %arg1: !spirv.matrix <4 x vector <3 xf32 >>) -> vector <4 xf32 > " None" {
41+ // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
42+ %result = spirv.VectorTimesMatrix %arg0 , %arg1 : vector <3 xf32 >, !spirv.matrix <4 x vector <3 xf32 >> -> vector <4 xf32 >
43+ spirv.ReturnValue %result : vector <4 xf32 >
44+ }
45+
3946 // CHECK-LABEL: @matrix_times_matrix_1
4047 spirv.func @matrix_times_matrix_1 (%arg0: !spirv.matrix <3 x vector <3 xf32 >>, %arg1: !spirv.matrix <3 x vector <3 xf32 >>) -> !spirv.matrix <3 x vector <3 xf32 >> " None" {
4148 // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
123130 return
124131}
125132
126-
127133// -----
128134
129135func.func @matrix_times_matrix_component_type_mismatch_2 (%arg0 : !spirv.matrix <3 x vector <3 xf64 >>, %arg1 : !spirv.matrix <3 x vector <3 xf32 >>){
@@ -155,3 +161,35 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
155161 %result = spirv.MatrixTimesVector %arg0 , %arg1 : !spirv.matrix <4 x vector <3 xf32 >>, vector <3 xf32 > -> vector <3 xf32 >
156162 return
157163}
164+
165+ // -----
166+
167+ func.func @vector_times_matrix_vector_matrix_mismatch (%arg0: vector <4 xf32 >, %arg1: !spirv.matrix <4 x vector <3 xf32 >>) {
168+ // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
169+ %result = spirv.VectorTimesMatrix %arg0 , %arg1 : vector <4 xf32 >, !spirv.matrix <4 x vector <3 xf32 >> -> vector <3 xf32 >
170+ return
171+ }
172+
173+ // -----
174+
175+ func.func @vector_times_matrix_result_matrix_mismatch (%arg0: vector <3 xf32 >, %arg1: !spirv.matrix <4 x vector <3 xf32 >>) {
176+ // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
177+ %result = spirv.VectorTimesMatrix %arg0 , %arg1 : vector <3 xf32 >, !spirv.matrix <4 x vector <3 xf32 >> -> vector <3 xf32 >
178+ return
179+ }
180+
181+ // -----
182+
183+ func.func @vector_times_matrix_vector_type_mismatch (%arg0: vector <3 xi32 >, %arg1: !spirv.matrix <4 x vector <3 xf32 >>) {
184+ // expected-error @+1 {{vector must be a vector with the same component type as the component type in result}}
185+ %result = spirv.VectorTimesMatrix %arg0 , %arg1 : vector <3 xi32 >, !spirv.matrix <4 x vector <3 xf32 >> -> vector <4 xf32 >
186+ return
187+ }
188+
189+ // -----
190+
191+ func.func @vector_times_matrix_matrix_type_mismatch (%arg0: vector <3 xf32 >, %arg1: !spirv.matrix <4 x vector <3 xf16 >>) {
192+ // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
193+ %result = spirv.VectorTimesMatrix %arg0 , %arg1 : vector <3 xf32 >, !spirv.matrix <4 x vector <3 xf16 >> -> vector <4 xf32 >
194+ return
195+ }
0 commit comments