
Commit aaecd1a

Tighten up backward pass staging.

1 parent 0313618

5 files changed: +2568 -74 lines


bin/mfa/makefile

Lines changed: 8 additions & 2 deletions
@@ -3,7 +3,9 @@ include ../../lib/config.mk
 CFLAGS := -std=c++17 -O0 -g -Wall -I"../.." -I"../../lib" $(CFLAGS)
 LDFLAGS := $(LDFLAGS) -framework QuartzCore
 
-TARGETS = gemm_scaffold na_gemm_splitk_bench na_int8_attention_bench na_int8_attention_backward_probe na_attention_backward_bench na_int8_matmul_bench dump_na_attention_source dump_na_int8_attention_source conv3d_layout_scaffold implicit_conv3d_scaffold gemm_kernel_introspect implicit_conv3d_bench conv3d_kernel_bench conv3d_branch_validate
+TARGETS = gemm_scaffold na_gemm_splitk_bench na_int8_attention_bench na_int8_attention_backward_probe na_attention_backward_bench sdpa_backward_compare_bench na_int8_matmul_bench dump_na_attention_source dump_na_int8_attention_source conv3d_layout_scaffold implicit_conv3d_scaffold gemm_kernel_introspect implicit_conv3d_bench conv3d_kernel_bench conv3d_branch_validate
+
+MPSGRAPH_LDFLAGS := $(LDFLAGS) -framework Foundation -framework MetalPerformanceShaders -framework MetalPerformanceShadersGraph
 
 COMMON_OBJS = \
 	Metal.local.o \
@@ -45,7 +47,7 @@ NA_BASELINE_ATTENTION_OBJS = \
 all: $(TARGETS)
 
 clean:
-	rm -f dump_na_source.o dump_na_attention_source.o dump_na_int8_attention_source.o gemm_scaffold.o na_gemm_splitk_bench.o na_int8_attention_bench.o na_int8_attention_backward_probe.o na_attention_backward_bench.o na_int8_matmul_bench.o conv3d_layout_scaffold.o implicit_conv3d_scaffold.o gemm_kernel_introspect.o implicit_conv3d_bench.o conv3d_kernel_bench.o conv3d_branch_validate.o $(COMMON_OBJS) $(NA_GEMM_OBJS) $(NA_CONV_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS) $(TARGETS) dump_na_source
+	rm -f dump_na_source.o dump_na_attention_source.o dump_na_int8_attention_source.o gemm_scaffold.o na_gemm_splitk_bench.o na_int8_attention_bench.o na_int8_attention_backward_probe.o na_attention_backward_bench.o sdpa_backward_compare_bench.o na_int8_matmul_bench.o conv3d_layout_scaffold.o implicit_conv3d_scaffold.o gemm_kernel_introspect.o implicit_conv3d_bench.o conv3d_kernel_bench.o conv3d_branch_validate.o $(COMMON_OBJS) $(NA_GEMM_OBJS) $(NA_CONV_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS) $(TARGETS) dump_na_source
 
 gemm_scaffold: gemm_scaffold.o $(COMMON_OBJS)
 	$(CC) -o $@ $^ $(LDFLAGS)
@@ -58,6 +60,8 @@ na_int8_attention_backward_probe: na_int8_attention_backward_probe.o $(COMMON_OB
 	$(CC) -o $@ $^ $(LDFLAGS)
 na_attention_backward_bench: na_attention_backward_bench.o $(COMMON_OBJS) $(NA_BASELINE_ATTENTION_OBJS)
 	$(CC) -o $@ $^ $(LDFLAGS)
+sdpa_backward_compare_bench: sdpa_backward_compare_bench.o $(COMMON_OBJS) $(NA_ATTENTION_OBJS) $(NA_BASELINE_ATTENTION_OBJS)
+	$(CC) -o $@ $^ $(MPSGRAPH_LDFLAGS)
 na_int8_matmul_bench: na_int8_matmul_bench.o $(COMMON_OBJS) $(NA_GEMM_OBJS)
 	$(CC) -o $@ $^ $(LDFLAGS)
 
@@ -99,6 +103,8 @@ na_int8_attention_backward_probe.o: na_int8_attention_backward_probe.cpp
 	$(CC) $< -o $@ -c $(subst -std=c++17,-std=gnu++17,$(CFLAGS))
 na_attention_backward_bench.o: na_attention_backward_bench.cpp
 	$(CC) $< -o $@ -c $(CFLAGS)
+sdpa_backward_compare_bench.o: sdpa_backward_compare_bench.mm
+	$(CC) $< -o $@ -c $(subst -std=c++17,-std=gnu++17,$(CFLAGS))
 na_int8_matmul_bench.o: na_int8_matmul_bench.cpp
 	$(CC) $< -o $@ -c $(CFLAGS)
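The new sdpa_backward_compare_bench target is an Objective-C++ translation unit (the .mm rule above) linked against MetalPerformanceShadersGraph, which suggests it checks the handwritten attention backward kernels against an MPSGraph-built scaled-dot-product-attention reference. A minimal sketch of how such a reference could be assembled, assuming tensors shaped [batch, sequence, head_dim]; buildReferenceGradients is a hypothetical helper, only the MPSGraph calls themselves are the framework's public API:

// Hypothetical sketch (not part of this commit): build an MPSGraph SDPA
// reference whose automatic gradients can be diffed against the
// handwritten dQ/dK/dV kernels.
#import <Foundation/Foundation.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

static NSDictionary<MPSGraphTensor *, MPSGraphTensor *> *
buildReferenceGradients(MPSGraph *graph, MPSGraphTensor *Q,
                        MPSGraphTensor *K, MPSGraphTensor *V, double scale)
{
    // S = scale * Q K^T: swap the last two axes of K, then matmul.
    MPSGraphTensor *Kt = [graph transposeTensor:K dimension:1 withDimension:2 name:nil];
    MPSGraphTensor *S = [graph matrixMultiplicationWithPrimaryTensor:Q
                                                     secondaryTensor:Kt
                                                                name:nil];
    MPSGraphTensor *c = [graph constantWithScalar:scale dataType:S.dataType];
    S = [graph multiplicationWithPrimaryTensor:S secondaryTensor:c name:nil];

    // P = softmax over the key axis (last dim of S); O = P V.
    MPSGraphTensor *P = [graph softMaxWithTensor:S axis:2 name:nil];
    MPSGraphTensor *O = [graph matrixMultiplicationWithPrimaryTensor:P
                                                     secondaryTensor:V
                                                                name:nil];

    // Sum O into a scalar so gradientForPrimaryTensor: yields dQ, dK, dV
    // (equivalent to backpropagating dO = all-ones).
    MPSGraphTensor *loss = [graph reductionSumWithTensor:O
                                                    axes:@[@0, @1, @2]
                                                    name:nil];
    return [graph gradientForPrimaryTensor:loss withTensors:@[Q, K, V] name:nil];
}

Feeding the same FP16 Q/K/V through such a graph and through the bench's own pipelines would give an apples-to-apples gradient comparison; a real harness would presumably supply a measured dO rather than the implicit all-ones.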

bin/mfa/na_attention_backward_bench.cpp

Lines changed: 5 additions & 5 deletions
@@ -213,8 +213,8 @@ AttentionOperands<GEMMOperandPrecision> create_fp16_backward_precisions()
 	memory_precisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP16;
 	memory_precisions[AttentionOperand::dK] = GEMMOperandPrecision::FP16;
 	memory_precisions[AttentionOperand::dV] = GEMMOperandPrecision::FP16;
-	memory_precisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
-	memory_precisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
+	memory_precisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
+	memory_precisions[AttentionOperand::D] = GEMMOperandPrecision::BF16;
 	return memory_precisions;
 }
 
@@ -393,7 +393,7 @@ ForwardPipeline create_forward_pipeline(MTL::Device* device, const AttentionCase
 	bundle.descriptor.Hk = attention.Hk;
 	bundle.descriptor.lowPrecisionInputs = true;
 	bundle.descriptor.isBF16 = false;
-	bundle.descriptor.lowPrecisionIntermediates = false;
+	bundle.descriptor.lowPrecisionIntermediates = true;
 	bundle.descriptor.matrixDimensions = simd::uint3 { attention.R, attention.C, attention.D };
 	bundle.descriptor.type = AttentionKernelType::forward;
 	bundle.descriptor.scale = create_scale(attention);
@@ -431,7 +431,7 @@ BackwardPipelines create_backward_pipelines(
 	bundle.query_descriptor.Hk = attention.Hk;
 	bundle.query_descriptor.lowPrecisionInputs = true;
 	bundle.query_descriptor.isBF16 = false;
-	bundle.query_descriptor.lowPrecisionIntermediates = false;
+	bundle.query_descriptor.lowPrecisionIntermediates = true;
 	bundle.query_descriptor.matrixDimensions = simd::uint3 { attention.R, attention.C, attention.D };
 	bundle.query_descriptor.type = AttentionKernelType::backwardQuery;
 	bundle.query_descriptor.scale = create_scale(attention);
@@ -941,7 +941,7 @@ int main(int argc, char** argv)
 	          << " blockC=" << forward_pipeline.kernel->blockDimensions[1]
 	          << " blockD=" << forward_pipeline.kernel->blockDimensions[2]
 	          << " simdgroups=" << forward_pipeline.kernel->executionSIMDGroups
-	          << " lowPrecisionIntermediates=false"
+	          << " lowPrecisionIntermediates=true"
 	          << '\n';
 	std::cout << "backward-kernel"
 	          << " queryBlockR=" << backward_pipelines.query_kernel->blockDimensions[0]
