@@ -3,13 +3,14 @@
 module {
   func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
-    %c32 = arith.constant 32 : index
+    %c16 = arith.constant 16 : index
+    %c64 = arith.constant 64 : index
     %c1024 = arith.constant 1024 : index
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
-      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [16, 64] [1, 1] : memref<1024x1024xf16> to memref<16x64xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [16, 1024] [1, 1] : memref<1024x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [64, 1024] [1, 1] : memref<1024x1024xf16> to memref<64x1024xf16, strided<[1024, 1], offset: ?>>
+      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<64x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<16x64xf16, strided<[1024, 1], offset: ?>>)
       scf.reduce
     }
     return
@@ -19,7 +20,7 @@ module {
 // CHECK-LABEL: func.func @matmul_transpose_b
 // CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

-// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
 // CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
 // CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
 // CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
@@ -43,9 +44,11 @@ module {
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
 // CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+// CHECK: %[[tB2:.+]] = xegpu.update_nd_offset %[[rootB]], [%c32, %c0]
+// CHECK: %[[tB3:.+]] = xegpu.update_nd_offset %[[rootB]], [%c48, %c0]

 // Create DPAS computation loop over tiled reduction dimension.
-// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+// CHECK: %[[res:.+]]:13 = scf.for{{.*}}%c0 to %c1024 step %c16
 // CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
 // CHECK-SAME: {
@@ -66,10 +69,10 @@ module {

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
-// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<16x16xf16> to vector<256xf16>
+// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<256xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
-// CHECK-COUNT-3: vector.extract_strided_slice
+// CHECK-COUNT-1: vector.extract_strided_slice

 // Perform DPAS computation.
 // CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
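Note on the updated CHECK constants: the new 16x64 output tile decomposes into DPAS-sized pieces, which accounts for the B-descriptor count, the extract count, and the iter_args arity above. A minimal sketch of that bookkeeping, assuming the Xe DPAS shape for f16 (8x16 accumulator tile); the variable names are illustrative only, derived from the CHECK lines rather than from the lowering pass itself:

# Illustrative only: reproduce the tile bookkeeping encoded in the
# CHECK lines, assuming an 8x16 (MxN) DPAS accumulator tile for f16.
DPAS_M, DPAS_N = 8, 16

tile_m, tile_n = 16, 64      # new scf.parallel tile (was 32x32)
m_chunks = tile_m // DPAS_M  # 2 A sub-tiles: 1 extract shown + CHECK-COUNT-1
n_chunks = tile_n // DPAS_N  # 4 B descriptors: tB, tB1, tB2, tB3

accumulators = m_chunks * n_chunks       # 8 DPAS accumulator tiles
iter_args = accumulators + 1 + n_chunks  # accumulators + tile A + B tiles
print(iter_args)             # 13, matching %[[res:.+]]:13
                             # (old 32x32 tile: 8 + 1 + 2 = 11)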