@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
4848
4949// -----
5050
51+ // No restriction on vector sizes to allow capturing workgroup-sized operations.
52+ // The operations can then be progressively resized through distribution down
53+ // to hardware-compatible sizes.
54+
55+ #map = affine_map<(d0, d1, d2) -> (d0, d2)>   // first map in indexing_maps: (d0, d2)
56+ #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>  // second map in indexing_maps: (d2, d1)
57+ #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // third map in indexing_maps: (d0, d1)
58+ func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
59+     %acc: vector<128x256xf32>) -> vector<128x256xf32> {
60+   %3 = vector.contract
61+     {indexing_maps = [#map, #map1, #map2],
62+     iterator_types = ["parallel", "parallel", "reduction"],  // d0, d1 parallel; d2 reduced
63+     kind = #vector.kind<add>} %lhs, %rhs, %acc
64+     : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
65+   return %3 : vector<128x256xf32>
66+ }
67+
68+ // CHECK-LABEL: @dpas_large_dims(
69+ // CHECK-SAME: %[[LHS:.+]]: vector<128x512xf16>,
70+ // CHECK-SAME: %[[RHS:.+]]: vector<512x256xf16>,
71+ // CHECK-SAME: %[[ACC:.+]]: vector<128x256xf32>
72+ // CHECK: %[[DPAS:.+]] = xegpu.dpas
73+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
74+ // CHECK-SAME: {{.*}}-> vector<128x256xf32>
75+ // CHECK: return %[[DPAS]]
76+
77+ // -----
78+
5179// For simplicity, only plain data layouts are currently supported.
5280// VNNI packing is applied later as a separate lowering step.
5381
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
138166
139167// CHECK-LABEL: @negative_gemm_transpose_b(
140168// CHECK: vector.contract
141-
142- // -----
143-
144- #map = affine_map<(d0, d1, d2) -> (d0, d2)>   // first map in indexing_maps: (d0, d2)
145- #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>  // second map in indexing_maps: (d2, d1)
146- #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // third map in indexing_maps: (d0, d1)
147- func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
148-     %acc: vector<8x32xf32>) -> vector<8x32xf32> {
149-   %3 = vector.contract
150-     {indexing_maps = [#map, #map1, #map2],
151-     iterator_types = ["parallel", "parallel", "reduction"],  // d0, d1 parallel; d2 reduced
152-     kind = #vector.kind<add>} %lhs, %rhs, %acc
153-     : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
154-   return %3 : vector<8x32xf32>
155- }
156-
157- // CHECK-LABEL: @negative_n_dim_size(
158- // CHECK: vector.contract
0 commit comments