@@ -5,6 +5,8 @@
 #mma32_scaled = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32, 64], isTransposed = true}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#padding = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 64]}>
+#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}>
 #smem = #ttg.shared_memory
 
 #linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
@@ -688,4 +690,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_invalid>
     tt.return
   }
+
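+  // Padding of +16 elements after every 512 leaves each lane's contiguous
+  // 4xf16 chunk intact, so the transposed ds_read lowering should still apply.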
+  // CHECK-LABEL: ds_transpose_with_padding
+  tt.func @ds_transpose_with_padding(%arg0: !ttg.memdesc<128x64xf16, #padding, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
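+  // A padding interval of one element is smaller than the 4xf16 chunk that
+  // ds_read_tr consumes per lane, so the transposed lowering must be rejected.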
+  // CHECK-LABEL: ds_transpose_padding_interval_too_small
+  tt.func @ds_transpose_padding_interval_too_small(%arg0: !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
 }