|
| 1 | +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx83 | FileCheck --dump-input-context=20 %s |
| 2 | + |
| 3 | +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> |
| 4 | +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth=4}> |
| 5 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { |
| 6 | +// CHECK-LABEL: cvt_mma_to_dot_fp8 |
| 7 | + tt.func @cvt_mma_to_dot_fp8(%ptr : !llvm.ptr, %arg0: tensor<128x64xf8E5M2, #mma>) { |
| 8 | + |
| 9 | + // As there are 64 elements per lane, we don't use variables to track them. |
| 10 | + |
| 11 | + // CHECK-COUNT-64: ld.param.b8 |
| 12 | + |
| 13 | + // Intra-warp layout conversions can be viewed as a permutation of register |
| 14 | + // and lane basis vectors. This can be read off from the linear layouts: |
| 15 | + // |
| 16 | + // #mma: register: [[0,1], [8,0], [0,8], [0,16], [0,32], [64,0]] |
| 17 | + // lane: [[0,2], [0,4], [1,0], [2,0], [4,0]] |
| 18 | + // warp: [[16,0], [32,0]] |
| 19 | + // |
| 20 | + // #dot_op: register: [[0,1], [0,2], [8,0], [0,16], [0,32], [64,0]] |
| 21 | + // lane: [[0,4], [0,8], [1,0], [2,0], [4,0]] |
| 22 | + // warp: [[16,0], [32,0]] |
| 23 | + // |
| 24 | + // The layout conversion is described by the permutation (r1 r2 l1 l0), |
| 25 | + // which factors as (r1 l1)(l0 l1)(r1 r2). |
| 26 | + // |
| 27 | + // Register basis vectors correspond to the bits of the indices of the 64 |
| 28 | + // separate registers which hold the original elements. Since we end up |
| 29 | + // packing 4 elements per register, we end up with only 16 registers in |
| 30 | + // total before shuffling. The `transferWithinWarp` implementation handles |
| 31 | + // register packing by ensuring that elements are packed together only if |
| 32 | + // under the layout conversion, they end up in the same destination lane. |
| 33 | + // To do this, it rearranges the 64 registers so that it can pack 4 |
| 34 | + // consecutive elements at a time according to their new register index. |
| 35 | + // |
| 36 | + // The transposition (r1 l1) above indicates that intially, elements with |
| 37 | + // register indices whose r1 bit is on are to be moved to new lanes. We thus |
| 38 | + // need to rearrange the registers. The algorithm chooses the next register |
| 39 | + // bit > 1 which is not used in a mixed transposition. In this case, |
| 40 | + // that bit is r2. Algebrically, this corresponds to conjugating the |
| 41 | + // permutation with (r1 r2). This produces (r1 r2)(r2 l1)(l0 l1). The new |
| 42 | + // (r1 r2) at the end rearranges elements after unpacking, and only |
| 43 | + // (r2 l1)(l0 l1) matters for tracking the movement of the packed registers. |
| 44 | + // From the point of view of the packed registers, the symbol `r2` now |
| 45 | + // corresponds to the 0th bit of a (packed) register's index. |
| 46 | + // |
| 47 | + // The transposition (r2 l1) is a bit swap which is implemented in-place as: |
| 48 | + // 1. r2 ^= l1 |
| 49 | + // 2. l1 ^= r2 |
| 50 | + // 3. r2 ^= l1. |
| 51 | + // The algorithm conjugates (l0 l1) through the first two stages to produce: |
| 52 | + // 1. r2 ^= l0 |
| 53 | + // 2a. l0 ^= r2 |
| 54 | + // 2b. (l0 l1) |
| 55 | + // 3. r2 ^= l1. |
| 56 | + // The first step is to get the value of l0. |
| 57 | + |
| 58 | + // CHECK: mov.u32 [[TID:%.*]], %tid.x; |
| 59 | + // CHECK: and.b32 [[L0_VAL:%.*]], [[TID]], 1; |
| 60 | + // CHECK: setp.eq.s32 [[L0_OFF:%.*]], [[L0_VAL]], 0; |
| 61 | + |
| 62 | + // This is used to perform 16 independent selects in stage 1. |
| 63 | + |
| 64 | + // CHECK-COUNT-16: selp.b32 {{.*}}, {{.*}}, [[L0_OFF]]; |
| 65 | + |
| 66 | + // Next, we apply (l0 l1) to the lane id to get the base source lane for |
| 67 | + // the index shuffles. This is step 2b above, but since we must specify |
| 68 | + // the *source* lane for a warp-shuffle, it gets applied first in practice: |
| 69 | + // |
| 70 | + // dstLane = ((l0 l1) \circ (l0 ^= r2))(srcLane) |
| 71 | + // srcLane = ((l0 ^= r2) \circ (l0 l1))(dstLane) |
| 72 | + // |
| 73 | + // To apply (l0 l1), we use a compile-time mask to collect the fixed bits, |
| 74 | + // and then we OR it with the shifted l0 and l1 values. |
| 75 | + |
| 76 | + // CHECK-DAG: and.b32 [[LANEID_FIXED_BITS:%.*]], [[TID]], 28; |
| 77 | + // CHECK-DAG: shl.b32 [[L0_TEMP:%.*]], [[L0_VAL]], 1; |
| 78 | + // CHECK-DAG: or.b32 [[LANEID_PART_PERM:%.*]], [[L0_TEMP]], [[LANEID_FIXED_BITS]]; |
| 79 | + // CHECK-DAG: bfe.u32 [[L1_TEMP:%.*]], [[TID]], 1, 1; |
| 80 | + // CHECK-DAG: or.b32 [[LANEID_PERM:%.*]], [[LANEID_PART_PERM]], [[L1_TEMP]]; |
| 81 | + |
| 82 | + // The index shuffles have source lane dependent on the value of the r2 bit. |
| 83 | + // Half of them use `LANEID_PERM` while the other half use `LANEID_PERM` |
| 84 | + // with the l0 bit flipped (step 2a). |
| 85 | + |
| 86 | + // CHECK-DAG: xor.b32 [[LANEID_PERM_F:%.*]], [[LANEID_PERM]], 1; |
| 87 | + |
| 88 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 89 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 90 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 91 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 92 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 93 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 94 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 95 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1; |
| 96 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 97 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 98 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 99 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 100 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 101 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 102 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 103 | + // CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1; |
| 104 | + |
| 105 | + // Finally, the last set of selects are performed, using the value of l1 as |
| 106 | + // the predicate (step 3). |
| 107 | + |
| 108 | + // CHECK-DAG: and.b32 [[L1_VAL:%.*]], [[TID]], 2; |
| 109 | + // CHECK-DAG: setp.eq.s32 [[L1_OFF:%.*]], [[L1_VAL]], 0; |
| 110 | + // CHECK-COUNT-16: selp.b32 {{.*}}, {{.*}}, [[L1_OFF]]; |
| 111 | + |
| 112 | + // CHECK-COUNT-64: bfe.u32 |
| 113 | + // CHECK-COUNT-64: st.volatile.global.b8 |
| 114 | + |
| 115 | + %0 = ttg.convert_layout %arg0 : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #dot_op> |
| 116 | + %1 = builtin.unrealized_conversion_cast %0 : tensor<128x64xf8E5M2, #dot_op> to !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> |
| 117 | + llvm.store volatile %1, %ptr : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>, !llvm.ptr |
| 118 | + |
| 119 | + tt.return |
| 120 | + } |
| 121 | +} |
0 commit comments