 // RUN: -test-transform-dialect-erase-schedule \
 // RUN: | FileCheck %s
 
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
 // CHECK-LABEL: func.func @main()
 func.func @main() {
   %c1 = arith.constant 1 : index
   %c128 = arith.constant 128 : index
 
   %0 = gpu.wait async
-  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
-  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
 
-  // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   // CHECK: %[[c64:.*]] = arith.constant 64 : index
-  // CHECK: %[[c8:.*]] = arith.constant 8 : index
-  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
-  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-  // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  // CHECK: %[[c32:.*]] = arith.constant 32 : index
+  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
+  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   // CHECK: %[[c8_2:.*]] = arith.constant 8 : index
-  // CHECK: %[[c128_2:.*]] = arith.constant 128 : index
-  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
-  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK: %[[c32_2:.*]] = arith.constant 32 : index
+  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
+  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
-    // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
-    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
     // CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
     // CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,18 +45,18 @@ func.func @main() {
     //
     // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
-    // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME: : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME: -> memref<64x32xf32, #gpu.address_space<workgroup>>
     //
     // CHECK: %[[c0_8:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
-    // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME: : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME: -> memref<8x32xf32, #gpu.address_space<workgroup>>
     //
-    // CHECK: %[[c6144:.*]] = arith.constant 6144 : index
-    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+    // CHECK: %[[c9216:.*]] = arith.constant 9216 : index
+    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c9216]] : <memorySpace = #gpu.address_space<workgroup>
     // CHECK: } else {
     // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
     // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
     // CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
     /// Both copies are matched and end up in the same async group.
-    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
-    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
 
     gpu.terminator
   }
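
Note: the updated expect_tx constant appears to be the combined payload of the two TMA loads into workgroup memory, matching the new 64x32 and 8x32 f32 tile shapes. A minimal sketch of that arithmetic (variable names are illustrative, not from the test):

    lhs_bytes = 64 * 32 * 4            # 64x32xf32 tile, 4 bytes per f32 -> 8192
    rhs_bytes = 8 * 32 * 4             # 8x32xf32 tile, 4 bytes per f32  -> 1024
    expect_tx = lhs_bytes + rhs_bytes  # 8192 + 1024 = 9216, the constant checked above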