Description
We have implemented Triton kernels for the matmul operations involving a telescoping cache in the telescoping-kernel branch. These kernels pass their respective correctness checks (also included), but deploying them to our training pipeline is not straightforward because Triton does not support atomic adds on bf16 tensors (see here).
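For context, a minimal split-K-style sketch of the pattern is below. This is not the actual telescoping-cache kernel; all names, strides, and block sizes are illustrative, and it only assumes the usual reason for needing an atomic add (multiple programs accumulating into the same output tile). It shows where the constraint bites: `tl.atomic_add` takes its dtype from the output pointer's element type, so the fp32 accumulator cannot be added atomically into a bf16 buffer and has to be cast to a supported dtype first.

```python
import triton
import triton.language as tl

@triton.jit
def _splitk_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program computes a partial BLOCK_M x BLOCK_N tile over one K-slice;
    # several programs therefore accumulate into the same output tile, which is
    # what forces the atomic add.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
    b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)

    acc = tl.dot(a, b)  # accumulated in fp32 regardless of input dtype

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # The atomic add inherits c_ptr's element type; bf16 is not supported here,
    # which is why the output buffer has to be fp16 or fp32.
    tl.atomic_add(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
```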
We instead cast to fp16 before this op, but loss curves on a test llama3-1.8B model diverge when we do this:
Loss curves do not diverge when running the kernels in fp32, but that sacrifices our speed gains. We are currently evaluating a variant that uses fp32 only for the atomic adds, and will update here.
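One low-friction shape for that experiment might be to keep the model tensors in bf16 but give the kernel an fp32 scratch output for the atomic adds, rounding back to bf16 once at the end. A host-side sketch under those assumptions (it reuses the hypothetical kernel above; the function name and block sizes are made up):

```python
import torch
import triton

def splitk_matmul_fp32_atomics(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a and b stay in bf16; only the atomic destination is fp32.
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = torch.zeros((M, N), device=a.device, dtype=torch.float32)  # fp32 atomic target
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
        triton.cdiv(K, meta["BLOCK_K"]),
    )
    _splitk_matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    )
    # Round to bf16 exactly once, after all partial sums have landed in fp32.
    return c.to(a.dtype)
```

The single rounding at the end is the point: every partial tile is summed in full fp32, so the only bf16 rounding happens at the final store rather than once per atomic update.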
Running these matmuls in fp16 also breaks the vanilla PyTorch implementation, so this is almost certainly a precision issue. If casting to fp32 internally does not fix the diverging loss, can the kernel code be massaged to avoid these issues?
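To separate the precision question from anything kernel-specific, a quick check along these lines might help; the shapes are arbitrary (a long K dimension stresses accumulation) and this is only a sanity probe, not our training configuration:

```python
import torch

torch.manual_seed(0)
M, K, N = 1024, 8192, 1024  # long K to stress accumulation
a = torch.randn(M, K, device="cuda")
b = torch.randn(K, N, device="cuda")

ref = a @ b  # fp32 reference
err_fp16 = ((a.half() @ b.half()).float() - ref).abs().max().item()
err_bf16 = ((a.bfloat16() @ b.bfloat16()).float() - ref).abs().max().item()
print(f"max abs error vs fp32 reference: fp16={err_fp16:.4f}, bf16={err_bf16:.4f}")
```

Caveat: cuBLAS may still accumulate these half-precision matmuls in fp32 internally, so this mostly measures input-rounding error rather than accumulator error; it is meant only as a first-pass signal.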
