Skip to content

Commit 1359ca7

Browse files
committed
Add: f16f32 MMA variant for Volta
1 parent 6e16165 commit 1359ca7

File tree

1 file changed

+93
-23
lines changed

1 file changed

+93
-23
lines changed

less_slow_sm70.ptx

Lines changed: 93 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,95 @@
3737
.target sm_70 // Target architecture (SM 7.0 - Volta GPUs)
3838
.address_size 64 // 64-bit addressing
3939

40-
.visible .entry tops_f16f16_sm70tc_16x16x16_loop128_ptx_kernel()
40+
// Kernel: issues `mma.sync.aligned.m8n8k4` (f16 inputs, f16 accumulators)
// 128 times in a loop, feeding the accumulators back into themselves.
// Takes no parameters; the accumulators are stored to memory only on a path
// guarded by a thread-ID check that is not expected to ever pass.
.visible .entry tops_f16f16_sm70mma_8x8x4_loop128_ptx_kernel()
4141
{
4242
// Accumulator registers: passed to the MMA as both its destination and its last source operand
4343
.reg .b32 accum_0, accum_1, accum_2, accum_3;
4444

45-
// Registers to hold packed 16-bit data for matrix a (8 registers)
46-
.reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3,
47-
matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7;
45+
// Registers to hold packed pairs of 16-bit data for matrix a (2 registers)
46+
.reg .b32 matrix_a_0, matrix_a_1;
4847

49-
// Registers to hold packed 16-bit data for matrix b (8 registers)
50-
.reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3,
51-
matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7;
48+
// Registers to hold packed pairs of 16-bit data for matrix b (2 registers)
49+
.reg .b32 matrix_b_0, matrix_b_1;
50+
51+
// General-purpose registers for loop control and constant values
52+
.reg .b32 loop_counter, loop_limit, packed_const;
53+
54+
// Predicate register for conditional branching (loop exit)
55+
.reg .pred exit_predicate;
56+
57+
// Set up the loop counter (starts at 0) and the loop limit (128 iterations)
58+
mov.u32 loop_counter, 0;
59+
mov.u32 loop_limit, 128;
60+
61+
// Zero-initialize the accumulator registers (0.0 written via `mov.f32` is an all-zero bit pattern)
62+
mov.f32 accum_0, 0.0;
63+
mov.f32 accum_1, 0.0;
64+
mov.f32 accum_2, 0.0;
65+
mov.f32 accum_3, 0.0;
66+
67+
// Placeholder input: the 16-bit pattern 0x0001 in both halves of one 32-bit register
68+
mov.b32 packed_const, 0x00010001;
69+
70+
// Initialize matrix a registers with the packed constant
71+
mov.b32 matrix_a_0, packed_const;
72+
mov.b32 matrix_a_1, packed_const;
73+
74+
// Initialize matrix b registers with the packed constant
75+
mov.b32 matrix_b_0, packed_const;
76+
mov.b32 matrix_b_1, packed_const;
77+
78+
// Main loop: leaves once loop_counter >= loop_limit, i.e. after 128 iterations
79+
loop_start:
80+
setp.ge.u32 exit_predicate, loop_counter, loop_limit;
81+
@exit_predicate bra loop_end;
82+
83+
mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16
84+
{ accum_0, accum_1, accum_2, accum_3 },
85+
{ matrix_a_0, matrix_a_1 },
86+
{ matrix_b_0, matrix_b_1 },
87+
{ accum_0, accum_1, accum_2, accum_3 };
88+
89+
// Increment the loop counter
90+
add.u32 loop_counter, loop_counter, 1;
91+
92+
// Branch back to the beginning of the loop
93+
bra loop_start;
94+
95+
loop_end:
96+
// If we simply exit, the computation will be optimized out!
97+
// Instead, we check for a condition that should never hold - the thread ID
98+
// being equal to `UINT_MAX` (0xFFFFFFFF) - and only in that case write the
99+
// accumulators to global memory at address 0 (NULL).
100+
.reg .u32 tid;
101+
.reg .pred impossible_predicate;
102+
mov.u32 tid, %tid.x; //? Special system registers start with `%`
103+
setp.ne.u32 impossible_predicate, tid, 0xFFFFFFFF; // NOTE: despite its name, true whenever tid != UINT_MAX
104+
@impossible_predicate bra loop_exit; // so this branch skips the stores below in the expected case
105+
106+
// Write the four accumulators to consecutive 4-byte slots starting at address 0:
107+
.reg .u64 store_ptr;
108+
mov.u64 store_ptr, 0;
109+
st.global.f32 [store_ptr], accum_0;
110+
st.global.f32 [store_ptr+4], accum_1;
111+
st.global.f32 [store_ptr+8], accum_2;
112+
st.global.f32 [store_ptr+12], accum_3;
113+
114+
loop_exit:
115+
ret;
116+
}
117+
118+
.visible .entry tops_f16f32_sm70mma_8x8x4_loop128_ptx_kernel()
119+
{
120+
// Accumulator registers used for both input and output of the MMA operation
121+
.reg .b32 accum_0, accum_1, accum_2, accum_3,
122+
accum_4, accum_5, accum_6, accum_7;
123+
124+
// Registers to hold packed pairs of 16-bit data for matrix a (4 declared; the MMA in this kernel's loop takes only matrix_a_0 and matrix_a_1)
125+
.reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3;
126+
127+
// Registers to hold packed pairs of 16-bit data for matrix b (4 declared; the MMA in this kernel's loop takes only matrix_b_0 and matrix_b_1)
128+
.reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3;
52129

53130
// General-purpose registers for loop control and constant values
54131
.reg .b32 loop_counter, loop_limit, packed_const;
@@ -74,33 +151,25 @@
74151
mov.b32 matrix_a_1, packed_const;
75152
mov.b32 matrix_a_2, packed_const;
76153
mov.b32 matrix_a_3, packed_const;
77-
mov.b32 matrix_a_4, packed_const;
78-
mov.b32 matrix_a_5, packed_const;
79-
mov.b32 matrix_a_6, packed_const;
80-
mov.b32 matrix_a_7, packed_const;
81154

82155
// Initialize matrix b registers with the packed constant
83156
mov.b32 matrix_b_0, packed_const;
84157
mov.b32 matrix_b_1, packed_const;
85158
mov.b32 matrix_b_2, packed_const;
86159
mov.b32 matrix_b_3, packed_const;
87-
mov.b32 matrix_b_4, packed_const;
88-
mov.b32 matrix_b_5, packed_const;
89-
mov.b32 matrix_b_6, packed_const;
90-
mov.b32 matrix_b_7, packed_const;
91160

92161
// The main loop will repeat for 128 iterations
93162
loop_start:
94163
setp.ge.u32 exit_predicate, loop_counter, loop_limit;
95164
@exit_predicate bra loop_end;
96165

97-
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
98-
{ accum_0, accum_1, accum_2, accum_3 },
99-
{ matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3,
100-
matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7 },
101-
{ matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3,
102-
matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7 },
103-
{ accum_0, accum_1, accum_2, accum_3 };
166+
mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
167+
{ accum_0, accum_1, accum_2, accum_3,
168+
accum_4, accum_5, accum_6, accum_7 },
169+
{ matrix_a_0, matrix_a_1 },
170+
{ matrix_b_0, matrix_b_1 },
171+
{ accum_0, accum_1, accum_2, accum_3,
172+
accum_4, accum_5, accum_6, accum_7 };
104173

105174
// Increment the loop counter
106175
add.u32 loop_counter, loop_counter, 1;
@@ -147,7 +216,8 @@ loop_exit:
147216
* with both arguments in shared memory!
148217
*
149218
* Because only one `.version` directive can be placed in each file, for newer
150-
* kernels, go to `less_slow_sm90a.ptx`.
219+
* kernels, go to `less_slow_sm80.ptx` for Ampere and `less_slow_sm90a.ptx`
220+
* for Hopper.
151221
*
152222
* @see PTX module-level directives:
153223
* https://docs.nvidia.com/cuda/parallel-thread-execution/#ptx-module-directives

0 commit comments

Comments
 (0)