|
57 | 57 |
|
58 | 58 | func.func private @printMemrefF32(memref<*xf32>) |
59 | 59 |
|
60 | | -memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64} |
61 | 60 | memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64} |
62 | 61 |
|
63 | 62 | func.func @main() { |
@@ -148,12 +147,11 @@ func.func @main() { |
148 | 147 | %c57344 = arith.constant 57344 : index |
149 | 148 | %c40960 = arith.constant 40960 : index |
150 | 149 |
|
151 | | - %tidx = gpu.thread_id x |
152 | | - %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3> |
153 | | - %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3> |
154 | | - %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3> |
155 | | - %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> |
| 150 | + %tidx = gpu.thread_id x |
156 | 151 | %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> |
| 152 | + %lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>> |
| 153 | + %rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>> |
| 154 | + |
157 | 155 | // Step 1. [GPU] Create Async Transactional Barriers (mbarriers) |
158 | 156 | %barrier = nvgpu.mbarrier.create -> !barrierType |
159 | 157 | %cnd = arith.cmpi eq, %tidx, %c0 : index |
@@ -202,11 +200,11 @@ func.func @main() { |
202 | 200 | // TMA wait |
203 | 201 | %phase_c0 = arith.constant 0 : i1 |
204 | 202 | nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType |
205 | | - %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3> |
206 | | - %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3> |
| 203 | + %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>> |
| 204 | + %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>> |
207 | 205 | // Descriptor WGMMA |
208 | | - %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>> |
209 | | - %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>> |
| 206 | + %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>> |
| 207 | + %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>> |
210 | 208 | // Perform WGMMA 128x128x64 |
211 | 209 | %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> |
212 | 210 | scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> |
@@ -271,7 +269,7 @@ func.func @main() { |
271 | 269 | vector.print str "Correct Results :" |
272 | 270 | vector.print %correctCount : i32 |
273 | 271 | vector.print str "Incorrect Results :" |
274 | | - vector.print %errorCount : i32 |
| 272 | + vector.print %errorCount : i32 |
275 | 273 |
|
276 | 274 | return |
277 | 275 | } |
0 commit comments