Commit 28e639e
Add: f16f32 WMMA variant for Ampere
1 parent 1359ca7 commit 28e639e

1 file changed: less_slow_sm80.ptx, 98 additions (+) and 5 deletions (-)
@@ -41,8 +41,13 @@
  * nicer by using the `<>` syntax to define many virtual registers without
  * explicitly naming them! We can also explicitly define them as `.f16x2` to
  * constrain the registers to packed half-precision pairs.
+ *
+ * We can also scale up from Quadpair-level MMA to Warp-level WMMA,
+ * synchronizing more threads to process larger tiles; the PTX docs even
+ * explicitly warn against using `mma.sync.m8n8k4` because of its
+ * performance issues!
  */
-.visible .entry tops_f16f16_sm90tc_16x16x16_loop128_ptx_kernel()
+.visible .entry tops_f16f16_sm80wmma_16x16x16_loop128_ptx_kernel()
 {
     // Accumulator registers used for both input and output of the MMA operation
     // https://docs.nvidia.com/cuda/parallel-thread-execution/#parameterized-variable-names
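For a sense of scale between the two granularities: one Warp-level m16n16k16 WMMA tile performs 16 × 16 × 16 = 4,096 multiply-accumulates, i.e. 8,192 FLOPs per instruction across the warp, while a Quadpair-level m8n8k4 MMA tile covers just 8 × 8 × 4 = 256 MACs, i.e. 512 FLOPs.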
@@ -127,15 +132,103 @@ loop_exit:
     ret;
 }
 
+.visible .entry tops_f16f32_sm80wmma_16x16x16_loop128_ptx_kernel()
+{
+    // Accumulator registers used for both input and output of the MMA operation
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#parameterized-variable-names
+    .reg .b32 accum<8>;
+
+    // Registers to hold packed 16-bit data for matrix A (8 registers)
+    .reg .f16x2 matrix_a<8>;
+
+    // Registers to hold packed 16-bit data for matrix B (8 registers)
+    .reg .f16x2 matrix_b<8>;
+
+    // General-purpose registers for loop control and constant values
+    .reg .b32 loop_counter, loop_limit, packed_const;
+
+    // Predicate register for conditional branching (loop exit)
+    .reg .pred exit_predicate;
+
+    // Set up loop counter and loop limit
+    mov.u32 loop_counter, 0;
+    mov.u32 loop_limit, 128;
+
+    // Zero-initialize all eight accumulators, as registers may contain noise
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
+    mov.f32 accum0, 0.0;
+    mov.f32 accum1, 0.0;
+    mov.f32 accum2, 0.0;
+    mov.f32 accum3, 0.0;
+    mov.f32 accum4, 0.0;
+    mov.f32 accum5, 0.0;
+    mov.f32 accum6, 0.0;
+    mov.f32 accum7, 0.0;
+
+    // Initialize constant for packed matrix data (placeholder)
+    mov.b32 packed_const, 0x00010001;
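+    // Note: read as an `.f16x2` pair, 0x00010001 encodes two subnormal
+    // halves; a pair of 1.0 half values would be 0x3C003C00 instead.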
+
+    // Initialize matrix A registers with the packed constant
+    mov.b32 matrix_a0, packed_const;
+    mov.b32 matrix_a1, packed_const;
+    mov.b32 matrix_a2, packed_const;
+    mov.b32 matrix_a3, packed_const;
+    mov.b32 matrix_a4, packed_const;
+    mov.b32 matrix_a5, packed_const;
+    mov.b32 matrix_a6, packed_const;
+    mov.b32 matrix_a7, packed_const;
+
+    // Initialize matrix B registers with the packed constant
+    mov.b32 matrix_b0, packed_const;
+    mov.b32 matrix_b1, packed_const;
+    mov.b32 matrix_b2, packed_const;
+    mov.b32 matrix_b3, packed_const;
+    mov.b32 matrix_b4, packed_const;
+    mov.b32 matrix_b5, packed_const;
+    mov.b32 matrix_b6, packed_const;
+    mov.b32 matrix_b7, packed_const;
+
+    // The main loop will repeat for 128 iterations
+loop_start:
+    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
+    @exit_predicate bra loop_exit;
+
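+    // The two trailing type suffixes below (`.f32.f32`) name the D and C
+    // accumulator fragments; the A and B operands are implicitly `.f16`
+    // for this shape.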
+    wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32
+        { accum0, accum1, accum2, accum3,
+          accum4, accum5, accum6, accum7 },
+        { matrix_a0, matrix_a1, matrix_a2, matrix_a3,
+          matrix_a4, matrix_a5, matrix_a6, matrix_a7 },
+        { matrix_b0, matrix_b1, matrix_b2, matrix_b3,
+          matrix_b4, matrix_b5, matrix_b6, matrix_b7 },
+        { accum0, accum1, accum2, accum3,
+          accum4, accum5, accum6, accum7 };
+
+    // Increment the loop counter
+    add.u32 loop_counter, loop_counter, 1;
+
+    // Branch back to the beginning of the loop
+    bra loop_start;
+
+loop_exit:
+    // Synchronize all threads in the block before writing the results out;
+    // `wmma.mma.sync` itself is already synchronous within the warp.
+    bar.sync 0;
+
+    // Use volatile stores to force the accumulator values to be written out.
+    // This dummy write (to a global variable) makes the work observable and
+    // prevents the multiplication pipeline from being optimized out.
+    st.global.volatile.f32 [dummy_sink_f32], accum0;
+    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
+    st.global.volatile.f32 [dummy_sink_f32+8], accum2;
+    st.global.volatile.f32 [dummy_sink_f32+12], accum3;
+    ret;
+}
+
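For reference, here is what the same f16 → f32 m16n16k16 tile looks like through the `nvcuda::wmma` C++ intrinsics. This is an illustrative sketch, not part of the commit: the kernel name, the `sink` parameter, and the constant inputs are invented for the example, while the fragment types and the `fill_fragment`/`mma_sync`/`store_matrix_sync` calls are the standard CUDA API:

    #include <cuda_fp16.h>
    #include <mma.h>
    using namespace nvcuda;

    __global__ void tops_f16f32_wmma_cpp_kernel(float *sink) {
        // Fragments mirroring the PTX kernel: f16 A/B tiles, f32 accumulators
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;
        wmma::fill_fragment(a, __float2half(1.0f));
        wmma::fill_fragment(b, __float2half(1.0f));
        wmma::fill_fragment(c, 0.0f);
        for (int i = 0; i != 128; ++i)
            wmma::mma_sync(c, a, b, c); // c = a * b + c, accumulated in place
        // Store the tile so the loop is observable and not optimized away
        wmma::store_matrix_sync(sink, c, 16, wmma::mem_row_major);
    }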
 /**
  * Each new generation of Tensor Cores supports a wider palette of numeric
  * types, "structured sparsity" modes, and asynchronous scheduling protocols.
  *
- * For double-precision numbers, we can go down to a granularity as small as
- * just 8x8x4 for `sm_80` or higher.
+ * ! For double-precision numbers, the smallest granularity is 8x8x4.
+ * ! Technically, it requires SM 8.0, but it's not a Warp-level MMA operation.
+ * ! It's a Quadpair-level MMA operation!
  */
 
-.visible .entry tops_f64f64_sm90tc_8x8x4_loop128_ptx_kernel()
+.visible .entry tops_f64f64_sm80mma_8x8x4_loop128_ptx_kernel()
 {
     // Registers to hold matrix A and B operands (each a single f64)
     .reg .f64 matrix_a, matrix_b;
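The double-precision shape referenced in the comment above is also reachable from CUDA C++ on `sm_80`. A minimal sketch under the same assumptions as before (hypothetical kernel name and `sink` buffer); the 8x8x4 `double` fragments are the documented CUDA API for this shape:

    #include <mma.h>
    using namespace nvcuda;

    __global__ void tops_f64_mma_cpp_kernel(double *sink) {
        // 8x8x4 is the only tile shape the API offers for double precision
        wmma::fragment<wmma::matrix_a, 8, 8, 4, double, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, 8, 8, 4, double, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 8, 8, 4, double> c;
        wmma::fill_fragment(a, 1.0);
        wmma::fill_fragment(b, 1.0);
        wmma::fill_fragment(c, 0.0);
        for (int i = 0; i != 128; ++i)
            wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(sink, c, 8, wmma::mem_row_major);
    }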
@@ -209,7 +302,7 @@ loop_exit:
  * is confusingly 19 bits wide! The synchronous variant would look familiar:
  */
 
-.visible .entry tops_tf32f32_sm90tc_16x16x8_loop128_ptx_kernel()
+.visible .entry tops_tf32f32_sm80wmma_16x16x8_loop128_ptx_kernel()
 {
     // Accumulator registers used for both input and output of the MMA operation
     .reg .b32 accum<8>;
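Since TF32 operands live in ordinary 32-bit registers, the C++ counterpart of this kernel would use `wmma::precision::tf32` fragments and the `__float_to_tf32` conversion helper. A sketch under the same assumptions as above (invented kernel name and `sink` buffer):

    #include <mma.h>
    using namespace nvcuda;

    __global__ void tops_tf32f32_wmma_cpp_kernel(float *sink) {
        // TF32 uses the m16n16k8 shape; elements are stored as 32-bit floats
        wmma::fragment<wmma::matrix_a, 16, 16, 8, wmma::precision::tf32, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, 16, 16, 8, wmma::precision::tf32, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 8, float> c;
        // Inputs must be explicitly truncated to the 19 meaningful TF32 bits
        wmma::fill_fragment(a, wmma::__float_to_tf32(1.0f));
        wmma::fill_fragment(b, wmma::__float_to_tf32(1.0f));
        wmma::fill_fragment(c, 0.0f);
        for (int i = 0; i != 128; ++i)
            wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(sink, c, 16, wmma::mem_row_major);
    }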
