Skip to content

Commit 0313618

Browse files
committed
Add initial backward kernel for MQA.
1 parent 9477f6d commit 0313618

13 files changed

+1778
-90
lines changed

AGENTS.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,20 @@ git checkout -- lib/nnc/cmd/ccv_nnc_cmd.inc lib/nnc/cmd/ccv_nnc_cmd.h lib/nnc/cm
174174
- temporarily force `use_neural_accelerators = 1` in `ccv_nnc_conv_mps.m`;
175175
- run `./mpsdnn.tests "mfa conv3d"` from `test/int/nnc`;
176176
- revert the force after validation so production code uses `ccv_nnc_mfa_has_neural_accelerators(context)`.
177+
- `NAInt8Attention` backward `dS` fallback note:
178+
- Earlier exploration suggested `dS -> half` might be a fallback worth keeping in mind, but on the current shipped `D=128` fixed-quant setup it is not a win.
179+
- Rechecked on `4096 x 4096 x 128` with the current selector:
180+
- fixed-quant `dS`: forward median `4.0495 ms`, backward median `21.8308 ms`, backward/forward ratio `5.3910x`
181+
- `dS -> half`: forward median `4.0552 ms`, backward median `23.0083 ms`, backward/forward ratio `5.6737x`
182+
- Takeaway:
183+
- on the current `NAInt8Attention` backward path, `dS -> half` regresses relative to fixed-quant `dS`
184+
- do not treat it as the preferred fallback unless the kernel is reworked again
185+
- `NAInt8Attention` backward fixed-quant selector note:
186+
- For the shipping `D=128` low-precision backward path, the safe production rule is:
187+
- query: `blockR=16`, `blockC=32`, `blockD=32`, `executionSIMDGroups=4`
188+
- key/value: `blockR=16`, `blockC=64`, `blockD=64`, `executionSIMDGroups=16`
189+
- Trust the backward absolute times more than any single reported ratio; forward medians on the probe can move enough to make one-off ratios look too optimistic.
190+
- Reliable current probe numbers are in this range:
191+
- `4096 x 4096 x 128`: backward median about `21-23 ms`, typically around `5.2x-5.6x`
192+
- `8192 x 8192 x 128`: backward median about `82-87 ms`, typically around `5.2x-5.4x`
193+
- Wider key/value traversal (`blockC=96`) can benchmark slightly faster in the probe but is not accuracy-safe on the real gradient test surface; keep `blockC=64` in production.

bin/mfa/makefile

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ include ../../lib/config.mk
33
CFLAGS := -std=c++17 -O0 -g -Wall -I"../.." -I"../../lib" $(CFLAGS)
44
LDFLAGS := $(LDFLAGS) -framework QuartzCore
55

6-
TARGETS = gemm_scaffold na_gemm_splitk_bench na_int8_attention_bench na_attention_backward_bench na_int8_matmul_bench dump_na_attention_source conv3d_layout_scaffold implicit_conv3d_scaffold gemm_kernel_introspect implicit_conv3d_bench conv3d_kernel_bench conv3d_branch_validate
6+
TARGETS = gemm_scaffold na_gemm_splitk_bench na_int8_attention_bench na_int8_attention_backward_probe na_attention_backward_bench na_int8_matmul_bench dump_na_attention_source dump_na_int8_attention_source conv3d_layout_scaffold implicit_conv3d_scaffold gemm_kernel_introspect implicit_conv3d_bench conv3d_kernel_bench conv3d_branch_validate
77

88
COMMON_OBJS = \
99
Metal.local.o \
@@ -45,7 +45,7 @@ NA_BASELINE_ATTENTION_OBJS = \
4545
all: $(TARGETS)
4646

4747
clean:
48-
rm -f dump_na_source.o dump_na_attention_source.o gemm_scaffold.o na_gemm_splitk_bench.o na_int8_attention_bench.o na_attention_backward_bench.o na_int8_matmul_bench.o conv3d_layout_scaffold.o implicit_conv3d_scaffold.o gemm_kernel_introspect.o implicit_conv3d_bench.o conv3d_kernel_bench.o conv3d_branch_validate.o $(COMMON_OBJS) $(NA_GEMM_OBJS) $(NA_CONV_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS) $(TARGETS) dump_na_source
48+
rm -f dump_na_source.o dump_na_attention_source.o dump_na_int8_attention_source.o gemm_scaffold.o na_gemm_splitk_bench.o na_int8_attention_bench.o na_int8_attention_backward_probe.o na_attention_backward_bench.o na_int8_matmul_bench.o conv3d_layout_scaffold.o implicit_conv3d_scaffold.o gemm_kernel_introspect.o implicit_conv3d_bench.o conv3d_kernel_bench.o conv3d_branch_validate.o $(COMMON_OBJS) $(NA_GEMM_OBJS) $(NA_CONV_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS) $(TARGETS) dump_na_source
4949

5050
gemm_scaffold: gemm_scaffold.o $(COMMON_OBJS)
5151
$(CC) -o $@ $^ $(LDFLAGS)
@@ -54,6 +54,8 @@ na_gemm_splitk_bench: na_gemm_splitk_bench.o $(COMMON_OBJS) $(NA_GEMM_OBJS)
5454
$(CC) -o $@ $^ $(LDFLAGS)
5555
na_int8_attention_bench: na_int8_attention_bench.o $(COMMON_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS)
5656
$(CC) -o $@ $^ $(LDFLAGS)
57+
na_int8_attention_backward_probe: na_int8_attention_backward_probe.o $(COMMON_OBJS) $(NA_ATTENTION_OBJS)
58+
$(CC) -o $@ $^ $(LDFLAGS)
5759
na_attention_backward_bench: na_attention_backward_bench.o $(COMMON_OBJS) $(NA_BASELINE_ATTENTION_OBJS)
5860
$(CC) -o $@ $^ $(LDFLAGS)
5961
na_int8_matmul_bench: na_int8_matmul_bench.o $(COMMON_OBJS) $(NA_GEMM_OBJS)
@@ -65,6 +67,9 @@ dump_na_source: dump_na_source.o $(COMMON_OBJS) NAMatMulKernelDescriptor.local.o
6567
dump_na_attention_source: dump_na_attention_source.o $(COMMON_OBJS) NAAttentionKernelDescriptor.local.o NAAttentionKernel.local.o NAAttentionDescriptor.local.o
6668
$(CC) -o $@ $^ $(LDFLAGS)
6769

70+
dump_na_int8_attention_source: dump_na_int8_attention_source.o $(COMMON_OBJS) NAInt8AttentionKernelDescriptor.local.o NAInt8AttentionKernel.local.o NAInt8AttentionDescriptor.local.o
71+
$(CC) -o $@ $^ $(LDFLAGS)
72+
6873
conv3d_layout_scaffold: conv3d_layout_scaffold.o $(COMMON_OBJS)
6974
$(CC) -o $@ $^ $(LDFLAGS)
7075

@@ -90,6 +95,8 @@ na_gemm_splitk_bench.o: na_gemm_splitk_bench.cpp
9095
$(CC) $< -o $@ -c $(CFLAGS)
9196
na_int8_attention_bench.o: na_int8_attention_bench.cpp
9297
$(CC) $< -o $@ -c $(CFLAGS)
98+
na_int8_attention_backward_probe.o: na_int8_attention_backward_probe.cpp
99+
$(CC) $< -o $@ -c $(subst -std=c++17,-std=gnu++17,$(CFLAGS))
93100
na_attention_backward_bench.o: na_attention_backward_bench.cpp
94101
$(CC) $< -o $@ -c $(CFLAGS)
95102
na_int8_matmul_bench.o: na_int8_matmul_bench.cpp
@@ -101,6 +108,9 @@ dump_na_source.o: dump_na_source.cpp
101108
dump_na_attention_source.o: dump_na_attention_source.cpp
102109
$(CC) $< -o $@ -c $(CFLAGS)
103110

111+
dump_na_int8_attention_source.o: dump_na_int8_attention_source.cpp
112+
$(CC) $< -o $@ -c $(CFLAGS)
113+
104114
conv3d_layout_scaffold.o: conv3d_layout_scaffold.cpp
105115
$(CC) $< -o $@ -c $(CFLAGS)
106116

@@ -173,6 +183,9 @@ NAInt8AttentionKernelDescriptor.local.o: ../../lib/nnc/mfa/kernels/NAInt8Attenti
173183
NAInt8AttentionKernel.local.o: ../../lib/nnc/mfa/kernels/NAInt8AttentionKernel.cpp
174184
$(CC) $< -o $@ -c $(CFLAGS)
175185

186+
NAInt8AttentionDescriptor.local.o: ../../lib/nnc/mfa/kernels/NAInt8AttentionDescriptor.cpp
187+
$(CC) $< -o $@ -c $(CFLAGS)
188+
176189
NAAttentionKernelDescriptor.local.o: ../../lib/nnc/mfa/kernels/NAAttentionKernelDescriptor.cpp
177190
$(CC) $< -o $@ -c $(CFLAGS)
178191

bin/mfa/na_int8_attention_bench.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,10 +1834,10 @@ double run_quantize_and_int8_once(
18341834
encoder->setBuffer(v_int8_buffer, 0, 2);
18351835
encoder->setBuffer(o_buffer, 0, 3);
18361836
encoder->setBuffer(l_buffer, 0, 4);
1837-
encoder->setBuffer(q_scale_buffer, 0, 5);
1838-
encoder->setBuffer(k_scale_buffer, 0, 6);
1839-
encoder->setBuffer(v_scale_buffer, 0, 7);
1840-
encoder->setBuffer(v_mean_buffer, 0, 8);
1837+
encoder->setBuffer(q_scale_buffer, 0, 10);
1838+
encoder->setBuffer(k_scale_buffer, 0, 11);
1839+
encoder->setBuffer(v_scale_buffer, 0, 12);
1840+
encoder->setBuffer(v_mean_buffer, 0, 14);
18411841
encoder->dispatchThreadgroups(
18421842
bundle.kernel->threadgroupsPerGrid(attention.batch, attention.R),
18431843
MTL::Size(bundle.kernel->threadgroupSize(bundle.pipeline.get()), 1, 1));
@@ -2122,10 +2122,10 @@ double run_int8_once(
21222122
encoder->setBuffer(v_buffer, 0, 2);
21232123
encoder->setBuffer(o_buffer, 0, 3);
21242124
encoder->setBuffer(l_buffer, 0, 4);
2125-
encoder->setBuffer(q_scale_buffer, 0, 5);
2126-
encoder->setBuffer(k_scale_buffer, 0, 6);
2127-
encoder->setBuffer(v_scale_buffer, 0, 7);
2128-
encoder->setBuffer(v_mean_buffer, 0, 8);
2125+
encoder->setBuffer(q_scale_buffer, 0, 10);
2126+
encoder->setBuffer(k_scale_buffer, 0, 11);
2127+
encoder->setBuffer(v_scale_buffer, 0, 12);
2128+
encoder->setBuffer(v_mean_buffer, 0, 14);
21292129
encoder->dispatchThreadgroups(
21302130
bundle.kernel->threadgroupsPerGrid(attention.batch, attention.R),
21312131
MTL::Size(bundle.kernel->threadgroupSize(bundle.pipeline.get()), 1, 1));

lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
612612
.masked = 0,
613613
.upcast = !is_downcast,
614614
.use_neural_accelerators = !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_NEURAL_ACCELERATORS) && ccv_nnc_mfa_has_neural_accelerators(context) && (mtl_data_type != 121 || ccv_nnc_mfa_neural_accelerators_support_bfloat(context)),
615-
.use_quantized_attention = 0,
615+
.use_quantized_attention = (cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_8I) != 0,
616616

617617
.batch_dims_q = { 0 },
618618
.batch_dims_mask = { 0 },

0 commit comments

Comments
 (0)