|
37 | 37 | .target sm_70 // Target architecture (SM 7.0 - Volta GPUs) |
38 | 38 | .address_size 64 // 64-bit addressing |
39 | 39 |
|
40 | | -.visible .entry tops_f16f16_sm70tc_16x16x16_loop128_ptx_kernel() |
| 40 | +.visible .entry tops_f16f16_sm70mma_8x8x4_loop128_ptx_kernel() |
41 | 41 | { |
42 | 42 | // Accumulator registers used for both input and output of the MMA operation |
43 | 43 | .reg .b32 accum_0, accum_1, accum_2, accum_3; |
44 | 44 |
|
45 | | - // Registers to hold packed 16-bit data for matrix a (8 registers) |
46 | | - .reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3, |
47 | | - matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7; |
| 45 | + // Registers to hold packed pairs of 16-bit data for matrix a (2 registers) |
| 46 | + .reg .b32 matrix_a_0, matrix_a_1; |
48 | 47 |
|
49 | | - // Registers to hold packed 16-bit data for matrix b (8 registers) |
50 | | - .reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3, |
51 | | - matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7; |
| 48 | + // Registers to hold packed pairs of 16-bit data for matrix b (2 registers) |
| 49 | + .reg .b32 matrix_b_0, matrix_b_1; |
| 50 | + |
| 51 | + // General-purpose registers for loop control and constant values |
| 52 | + .reg .b32 loop_counter, loop_limit, packed_const; |
| 53 | + |
| 54 | + // Predicate register for conditional branching (loop exit) |
| 55 | + .reg .pred exit_predicate; |
| 56 | + |
| 57 | + // Set up loop counter and loop limit |
| 58 | + mov.u32 loop_counter, 0; |
| 59 | + mov.u32 loop_limit, 128; |
| 60 | + |
| 61 | + // Zero-initialize the accumulator registers |
| 62 | + mov.f32 accum_0, 0.0; |
| 63 | + mov.f32 accum_1, 0.0; |
| 64 | + mov.f32 accum_2, 0.0; |
| 65 | + mov.f32 accum_3, 0.0; |
| 66 | + |
| 67 | + // Initialize a placeholder constant for packed matrix data: two 16-bit halves, each with bit pattern 0x0001 |
| 68 | + mov.b32 packed_const, 0x00010001; |
| 69 | + |
| 70 | + // Initialize matrix a registers with the packed constant |
| 71 | + mov.b32 matrix_a_0, packed_const; |
| 72 | + mov.b32 matrix_a_1, packed_const; |
| 73 | + |
| 74 | + // Initialize matrix b registers with the packed constant |
| 75 | + mov.b32 matrix_b_0, packed_const; |
| 76 | + mov.b32 matrix_b_1, packed_const; |
| 77 | + |
| 78 | + // The main loop will repeat for 128 iterations |
| 79 | +loop_start: |
| 80 | + setp.ge.u32 exit_predicate, loop_counter, loop_limit; |
| 81 | + @exit_predicate bra loop_end; |
| 82 | + |
| 83 | + mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 |
| 84 | + { accum_0, accum_1, accum_2, accum_3 }, |
| 85 | + { matrix_a_0, matrix_a_1 }, |
| 86 | + { matrix_b_0, matrix_b_1 }, |
| 87 | + { accum_0, accum_1, accum_2, accum_3 }; |
| 88 | + |
| 89 | + // Increment the loop counter |
| 90 | + add.u32 loop_counter, loop_counter, 1; |
| 91 | + |
| 92 | + // Branch back to the beginning of the loop |
| 93 | + bra loop_start; |
| 94 | + |
| 95 | +loop_end: |
| 96 | + // If we simply exit, the computation will be optimized out! |
| 97 | + // Instead, let's check for an impossible condition, such as the thread ID |
| 98 | + // being equal to `UINT_MAX`, and only in that case store the accumulators |
| 99 | + // to global memory at the NULL address. |
| 100 | + .reg .u32 tid; |
| 101 | + .reg .pred impossible_predicate; |
| 102 | + mov.u32 tid, %tid.x; //? Special system registers start with `%` |
| 103 | + setp.ne.u32 impossible_predicate, tid, 0xFFFFFFFF; |
| 104 | + @impossible_predicate bra loop_exit; |
| 105 | + |
| 106 | + // Write into memory: |
| 107 | + .reg .u64 store_ptr; |
| 108 | + mov.u64 store_ptr, 0; |
| 109 | + st.global.f32 [store_ptr], accum_0; |
| 110 | + st.global.f32 [store_ptr+4], accum_1; |
| 111 | + st.global.f32 [store_ptr+8], accum_2; |
| 112 | + st.global.f32 [store_ptr+12], accum_3; |
| 113 | + |
| 114 | +loop_exit: |
| 115 | + ret; |
| 116 | +} |
| 117 | + |
| 118 | +.visible .entry tops_f16f32_sm70mma_8x8x4_loop128_ptx_kernel() |
| 119 | +{ |
| 120 | + // Accumulator registers used for both input and output of the MMA operation |
| 121 | + .reg .b32 accum_0, accum_1, accum_2, accum_3, |
| 122 | + accum_4, accum_5, accum_6, accum_7; |
| 123 | + |
| 124 | + // Registers to hold packed pairs of 16-bit data for matrix a (4 declared; NOTE(review): the `mma` in the loop only takes matrix_a_0 and matrix_a_1) |
| 125 | + .reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3; |
| 126 | + |
| 127 | + // Registers to hold packed pairs of 16-bit data for matrix b (4 declared; NOTE(review): the `mma` in the loop only takes matrix_b_0 and matrix_b_1) |
| 128 | + .reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3; |
52 | 129 |
|
53 | 130 | // General-purpose registers for loop control and constant values |
54 | 131 | .reg .b32 loop_counter, loop_limit, packed_const; |
|
74 | 151 | mov.b32 matrix_a_1, packed_const; |
75 | 152 | mov.b32 matrix_a_2, packed_const; |
76 | 153 | mov.b32 matrix_a_3, packed_const; |
77 | | - mov.b32 matrix_a_4, packed_const; |
78 | | - mov.b32 matrix_a_5, packed_const; |
79 | | - mov.b32 matrix_a_6, packed_const; |
80 | | - mov.b32 matrix_a_7, packed_const; |
81 | 154 |
|
82 | 155 | // Initialize matrix b registers with the packed constant |
83 | 156 | mov.b32 matrix_b_0, packed_const; |
84 | 157 | mov.b32 matrix_b_1, packed_const; |
85 | 158 | mov.b32 matrix_b_2, packed_const; |
86 | 159 | mov.b32 matrix_b_3, packed_const; |
87 | | - mov.b32 matrix_b_4, packed_const; |
88 | | - mov.b32 matrix_b_5, packed_const; |
89 | | - mov.b32 matrix_b_6, packed_const; |
90 | | - mov.b32 matrix_b_7, packed_const; |
91 | 160 |
|
92 | 161 | // The main loop will repeat for 128 iterations |
93 | 162 | loop_start: |
94 | 163 | setp.ge.u32 exit_predicate, loop_counter, loop_limit; |
95 | 164 | @exit_predicate bra loop_end; |
96 | 165 |
|
97 | | - wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 |
98 | | - { accum_0, accum_1, accum_2, accum_3 }, |
99 | | - { matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3, |
100 | | - matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7 }, |
101 | | - { matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3, |
102 | | - matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7 }, |
103 | | - { accum_0, accum_1, accum_2, accum_3 }; |
| 166 | + mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 |
| 167 | + { accum_0, accum_1, accum_2, accum_3, |
| 168 | + accum_4, accum_5, accum_6, accum_7 }, |
| 169 | + { matrix_a_0, matrix_a_1 }, |
| 170 | + { matrix_b_0, matrix_b_1 }, |
| 171 | + { accum_0, accum_1, accum_2, accum_3, |
| 172 | + accum_4, accum_5, accum_6, accum_7 }; |
104 | 173 |
|
105 | 174 | // Increment the loop counter |
106 | 175 | add.u32 loop_counter, loop_counter, 1; |
@@ -147,7 +216,8 @@ loop_exit: |
147 | 216 | * with both arguments in shared memory! |
148 | 217 | * |
149 | 218 | * Because only one `.version` directive can be placed in each file, for newer |
150 | | - * kernels, go to `less_slow_sm90a.ptx`. |
| 219 | + * kernels, go to `less_slow_sm80.ptx` for Ampere and `less_slow_sm90a.ptx` |
| 220 | + * for Hopper. |
151 | 221 | * |
152 | 222 | * @see PTX module-level directives: |
153 | 223 | * https://docs.nvidia.com/cuda/parallel-thread-execution/#ptx-module-directives |
|
0 commit comments