|
41 | 41 | * nicer by using the `<>` syntax to define many virtual registers without |
42 | 42 | * explicitly naming them! We can also explicitly define them as `.f16x2` to |
43 | 43 | * constrain the registers to packed half-precision pairs. |
| 44 | + * |
| 45 | + * We can also scale from a Quadpair-level MMA to the Warp-level WMMA, |
| 46 | + * synchronizing more threads to process larger tiles. The PTX docs |
| 47 | + * explicitly warn that the `mma.sync.m8n8k4` shape can suffer degraded |
| 48 | + * performance on other architectures, so it's best avoided! |
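| | + * |
| | + * A minimal sketch of the two spellings (operands elided), following the |
| | + * PTX ISA docs: |
| | + * |
| | + * mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 d, a, b, c; |
| | + * wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 d, a, b, c; |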
44 | 49 | */ |
45 | | -.visible .entry tops_f16f16_sm90tc_16x16x16_loop128_ptx_kernel() |
| 50 | +.visible .entry tops_f16f16_sm80wmma_16x16x16_loop128_ptx_kernel() |
46 | 51 | { |
47 | 52 | // Accumulator registers used for both input and output of the MMA operation |
48 | 53 | // https://docs.nvidia.com/cuda/parallel-thread-execution/#parameterized-variable-names |
@@ -127,15 +132,103 @@ loop_exit: |
127 | 132 | ret; |
128 | 133 | } |
129 | 134 |
|
| 135 | +.visible .entry tops_f16f32_sm80wmma_16x16x16_loop128_ptx_kernel() |
| 136 | +{ |
| 137 | + // Accumulator registers used for both input and output of the MMA operation |
| 138 | + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parameterized-variable-names |
| 139 | + .reg .b32 accum<8>; |
| 140 | + |
| 141 | + // Registers to hold packed 16-bit data for matrix A (8 registers) |
| 142 | + .reg .f16x2 matrix_a<8>; |
| 143 | + |
| 144 | + // Registers to hold packed 16-bit data for matrix B (8 registers) |
| 145 | + .reg .f16x2 matrix_b<8>; |
| 146 | + |
| 147 | + // General-purpose registers for loop control and constant values |
| 148 | + .reg .b32 loop_counter, loop_limit, packed_const; |
| 149 | + |
| 150 | + // Predicate register for conditional branching (loop exit) |
| 151 | + .reg .pred exit_predicate; |
| 152 | + |
| 153 | + // Set up loop counter and loop limit |
| 154 | + mov.u32 loop_counter, 0; |
| 155 | + mov.u32 loop_limit, 128; |
| 156 | + |
| 157 | + // Zero-initialize the accumulators, as registers may contain noise |
| 158 | + // https://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces |
| 159 | + mov.f32 accum0, 0.0; |
| 160 | + mov.f32 accum1, 0.0; |
| 161 | + mov.f32 accum2, 0.0; |
| 162 | + mov.f32 accum3, 0.0; |
| | + mov.f32 accum4, 0.0; |
| | + mov.f32 accum5, 0.0; |
| | + mov.f32 accum6, 0.0; |
| | + mov.f32 accum7, 0.0; |
| 163 | + |
| 164 | + // Constant for the packed matrix data: each half is 0x0001, the smallest |
| | + // f16 subnormal; the exact value doesn't matter for a throughput test. |
| 165 | + mov.b32 packed_const, 0x00010001; |
| 166 | + |
| 167 | + // Initialize matrix a registers with the packed constant |
| 168 | + mov.b32 matrix_a0, packed_const; |
| 169 | + mov.b32 matrix_a1, packed_const; |
| 170 | + mov.b32 matrix_a2, packed_const; |
| 171 | + mov.b32 matrix_a3, packed_const; |
| 172 | + mov.b32 matrix_a4, packed_const; |
| 173 | + mov.b32 matrix_a5, packed_const; |
| 174 | + mov.b32 matrix_a6, packed_const; |
| 175 | + mov.b32 matrix_a7, packed_const; |
| 176 | + |
| 177 | + // Initialize matrix b registers with the packed constant |
| 178 | + mov.b32 matrix_b0, packed_const; |
| 179 | + mov.b32 matrix_b1, packed_const; |
| 180 | + mov.b32 matrix_b2, packed_const; |
| 181 | + mov.b32 matrix_b3, packed_const; |
| 182 | + mov.b32 matrix_b4, packed_const; |
| 183 | + mov.b32 matrix_b5, packed_const; |
| 184 | + mov.b32 matrix_b6, packed_const; |
| 185 | + mov.b32 matrix_b7, packed_const; |
| 186 | + |
| 187 | + // The main loop will repeat for 128 iterations |
| 188 | +loop_start: |
| 189 | + setp.ge.u32 exit_predicate, loop_counter, loop_limit; |
| 190 | + @exit_predicate bra loop_exit; |
| 191 | + |
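| | + // Warp-level 16x16x16 MMA: per the PTX fragment tables, each thread holds |
| | + // eight .f16x2 fragments of A and B, and eight .f32 accumulator fragments. |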
| 192 | + wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 |
| 193 | + { accum0, accum1, accum2, accum3, |
| 194 | + accum4, accum5, accum6, accum7 }, |
| 195 | + { matrix_a0, matrix_a1, matrix_a2, matrix_a3, |
| 196 | + matrix_a4, matrix_a5, matrix_a6, matrix_a7 }, |
| 197 | + { matrix_b0, matrix_b1, matrix_b2, matrix_b3, |
| 198 | + matrix_b4, matrix_b5, matrix_b6, matrix_b7 }, |
| 199 | + { accum0, accum1, accum2, accum3, |
| 200 | + accum4, accum5, accum6, accum7 }; |
| 201 | + |
| 202 | + // Increment the loop counter |
| 203 | + add.u32 loop_counter, loop_counter, 1; |
| 204 | + |
| 205 | + // Branch back to the beginning of the loop |
| 206 | + bra loop_start; |
| 207 | + |
| 208 | +loop_exit: |
| 209 | + // Synchronize the block before the stores; wmma.mma.sync itself is |
| | + // already warp-synchronous, so this is just a safety net. |
| 210 | + bar.sync 0; |
| 211 | + |
| 212 | + // Use volatile stores to force the accumulator values to be written out. |
| 213 | + // This dummy write (to a global variable) makes the work observable and |
| 214 | + // prevents the multiplication pipeline from being optimized out. |
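| | + // The WMMA writes all eight accumulators in one instruction, so sinking |
| | + // the first four is enough to keep the whole computation alive. |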
| 215 | + st.volatile.global.f32 [dummy_sink_f32], accum0; |
| 216 | + st.volatile.global.f32 [dummy_sink_f32+4], accum1; |
| 217 | + st.volatile.global.f32 [dummy_sink_f32+8], accum2; |
| 218 | + st.volatile.global.f32 [dummy_sink_f32+12], accum3; |
| 219 | + ret; |
| 220 | +} |
| 221 | + |
130 | 222 | /** |
131 | 223 | * Each new generation of Tensor Cores supports a wider palette of numeric |
132 | 224 | * types, "structured sparsity" modes, and asynchronous scheduling protocols. |
133 | 225 | * |
134 | | - * For double-precision numbers, we can go down to a granularity as small as |
135 | | - * just 8x8x4 for `sm_80` or higher. |
| 226 | + * ! For double-precision numbers, the smallest granularity is 8x8x4. |
| 227 | + * ! Technically, it requires SM 8.0, but it's not a Warp-level MMA operation. |
| 228 | + * ! It's a Quadpair-level MMA operation! |
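| | + * |
| | + * A minimal sketch of the shape (operands elided): per the PTX docs, A and |
| | + * B each occupy one .f64 register, while C and D are register pairs: |
| | + * |
| | + * mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 |
| | + * {d0, d1}, {a}, {b}, {c0, c1}; |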
136 | 229 | */ |
137 | 230 |
|
138 | | -.visible .entry tops_f64f64_sm90tc_8x8x4_loop128_ptx_kernel() |
| 231 | +.visible .entry tops_f64f64_sm80mma_8x8x4_loop128_ptx_kernel() |
139 | 232 | { |
140 | 233 | // Registers to hold matrix A and B operands (each a single f64) |
141 | 234 | .reg .f64 matrix_a, matrix_b; |
@@ -209,7 +302,7 @@ loop_exit: |
209 | 302 | * is confusingly 19 bits wide! The synchronous variant would look familiar: |
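| | * (That's 1 sign + 8 exponent + 10 mantissa bits, stored in a 32-bit |
| | * container: f32's dynamic range at roughly f16's precision.) |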
210 | 303 | */ |
211 | 304 |
|
212 | | - .visible .entry tops_tf32f32_sm90tc_16x16x8_loop128_ptx_kernel() |
| 305 | + .visible .entry tops_tf32f32_sm80wmma_16x16x8_loop128_ptx_kernel() |
213 | 306 | { |
214 | 307 | // Accumulator registers used for both input and output of the MMA operation |
215 | 308 | .reg .b32 accum<8>; |
|