$$p'(x) = \mathrm{norm}(\max(0, p(x) - q(x)))$$

In the paper [@leviathan2023fast], Leviathan et al. proved the
correctness of this adjusted distribution for resampling.
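
To make the verification step concrete, below is a minimal NumPy
sketch of the accept/resample rule from [@leviathan2023fast]; the
`verify_token` helper and its argument layout are illustrative
assumptions, not a reference implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def verify_token(p, q, x):
    """One accept/reject step of speculative decoding (illustrative sketch).

    p, q: target- and draft-model distributions over the vocabulary at
    the current position; x: the token the draft model sampled from q.
    """
    # Accept the draft token with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    # On rejection, resample from the adjusted distribution
    # p'(x) = norm(max(0, p(x) - q(x))).
    residual = np.maximum(0.0, p - q)
    residual /= residual.sum()
    return rng.choice(len(p), p=residual), False
```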

Assume that the execution time for a single step of the target model
is $T$ and that of the draft model is $cT$, where $0 < c \leq 1$. The
standard procedure of using the target model alone to generate
$\gamma + 1$ tokens would require a total time of $\gamma T + T$. In
contrast, with speculative decoding, where $\gamma + 1$ tokens are
produced ($\gamma$ by the draft model and one additional token by the
target model during the parallel verification), the time required is
$\gamma cT + T$. If all $\gamma$ draft tokens are accepted by the
target model and $c$ is small enough that $cT \ll T$, speculative
decoding can significantly reduce decoding latency.

To further explain, let $\beta$ be the acceptance rate given a prefix
and $\alpha = E(\beta)$; $\alpha$ is a natural measure of how well the
draft model approximates the target model. Assuming the $\beta$s are
i.i.d., the expected number of tokens generated by one speculative
superstep is $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$
[@leviathan2023fast]. Since one superstep takes $\gamma cT + T$, the
expected time for generating one token with speculative decoding is
$\frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T$. Choosing a good
$\gamma$ together with a well-aligned, efficient draft model (large
$\alpha$ and small $c$) yields the desired speedup.
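
As a quick numerical illustration of these formulas, the short script
below evaluates the expected per-token time for a range of $\gamma$
values; the chosen $\alpha$ and $c$ are made-up assumptions, not
measurements.

```python
def expected_tokens(alpha, gamma):
    # Expected tokens per superstep: (1 - alpha^(gamma+1)) / (1 - alpha).
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def time_per_token(alpha, gamma, c, T=1.0):
    # One superstep costs gamma*c*T + T and yields expected_tokens tokens.
    return (c * gamma + 1) * T / expected_tokens(alpha, gamma)

# Illustrative (assumed) values: alpha = 0.8, c = 0.05.
for gamma in range(1, 9):
    print(gamma, round(time_per_token(0.8, gamma, 0.05), 3))
```

With these assumed numbers, the per-token time falls quickly for small
$\gamma$ and then flattens out, which matches the diminishing returns
discussed next.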

Nevertheless, as $\gamma$ grows, it becomes progressively harder for a
draft model to generate draft tokens with a high acceptance rate,
especially since the likelihood of acceptance typically diminishes
once $\gamma$ exceeds a certain value. In the worst case, all draft
tokens are rejected by the target model, and only the one token
resampled from the adjusted distribution is decoded in that superstep.
The superstep still costs $\gamma cT + T$ but yields a single token,
so the $\gamma cT$ spent on drafting is wasted compared to generating
one token directly with the target model; in addition, the draft model
keeps consuming GPU memory.

Therefore, finding the best $\gamma$, or designing a draft model whose
tokens are reliably accepted by the target model, is important.
Several strategies can be employed to address this issue effectively.
For example:

**Self-Derived Drafts from Target Models**

Is it possible to use the target model itself as the draft model,
rather than employing a separate smaller model that increases GPU
memory usage? The answer is yes. Compared with the original approach,
the modification is to let the target model produce its own draft
tokens and then self-verify them. The advantages of this method are:

1. Since the draft model is essentially the target model itself, it is
   sufficiently robust to maintain a stable acceptance rate.

2. Only one model needs to be kept in GPU memory.

The challenge now lies in generating multiple future tokens in a
single decoding step. To achieve this, the idea is to append
additional parallel output layers to the model's existing output
layer. Stern et al. first proposed this method in
[@stern2018blockwise].

The training of these extra layers can either start from scratch
together with the target model or involve fine-tuning a pre-trained
model. This approach forms the basis of Medusa [@medusa]. Medusa's
architecture attaches extra "Medusa heads" after the last hidden
layer. This design enables the model to generate a range of token
candidates in just one decoding step. These candidates then undergo a
self-verification process, and only the accepted tokens are kept.
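
A rough PyTorch sketch of the extra-heads idea is shown below; the
head width, depth, and activation are simplified assumptions and do
not follow the exact architecture of [@medusa].

```python
import torch
import torch.nn as nn

class ExtraDecodingHeads(nn.Module):
    """Extra heads that each predict a token k steps ahead (simplified sketch)."""

    def __init__(self, hidden_size, vocab_size, num_heads=4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, vocab_size),
            )
            for _ in range(num_heads)
        )

    def forward(self, last_hidden):
        # last_hidden: (batch, hidden_size), the final hidden state of
        # the target model. Head k returns logits for position t + k + 1.
        return [head(last_hidden) for head in self.heads]
```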

Other methodologies, such as applying knowledge distillation between
the draft and target models, employing multiple draft models instead
of just one, and replacing the draft model with a retrieval dataset,
are still being investigated to determine their effectiveness and
reliability.

Speculative decoding is an effective technique that uses a smaller
model to reduce the decoding overhead of a larger model. By developing
a well-trained and well-aligned draft model, the efficiency of the
decoding process can be significantly improved.

## FlashAttention

FlashAttention is an optimization technique that exploits the GPU
memory hierarchy to make attention computation in transformer models
more efficient in terms of both memory usage and speed.

Dao et al. first proposed this approach in [@dao2022flashattention].
They noted the absence of *IO-awareness* -- consideration of I/O
across GPU memory layers -- in the classic scaled dot-product
attention algorithm. To address this, they introduced FlashAttention,
an enhanced version of the attention algorithm designed to minimize
accesses to the GPU's high-bandwidth memory (HBM). This innovation
leads to significant gains in both computational speed and throughput.

Figure :numref:`ch-deploy/memory` shows the memory hierarchy with the
corresponding bandwidths. The main goal of FlashAttention is to avoid
reading and writing the large attention matrix to and from HBM, and to
perform as much of the computation as possible in SRAM.
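
The following NumPy sketch illustrates the core tiling idea with an
online softmax: scores are computed one key/value block at a time,
with running row maxima and normalizers, so the full $N \times N$
attention matrix is never materialized. Here `block_size` stands in
for what fits in SRAM; the real kernel also tiles over queries and
runs on-GPU, so this is a conceptual sketch, not the actual algorithm
from [@dao2022flashattention].

```python
import numpy as np

def flash_attention(Q, K, V, block_size=64):
    """Tiled attention with online softmax (minimal NumPy sketch)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)
    row_max = np.full(n, -np.inf)  # running max of scores per query row
    row_sum = np.zeros(n)          # running softmax denominator

    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        scores = (Q @ Kb.T) * scale          # scores for this block only

        new_max = np.maximum(row_max, scores.max(axis=1))
        # Rescale the previously accumulated numerator and denominator.
        correction = np.exp(row_max - new_max)
        p = np.exp(scores - new_max[:, None])
        row_sum = row_sum * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ Vb
        row_max = new_max

    return out / row_sum[:, None]

# Sanity check against the direct (full-matrix) computation.
rng = np.random.default_rng(1)
Q, K, V = rng.standard_normal((3, 128, 32))
S = (Q @ K.T) / np.sqrt(32)
ref = np.exp(S - S.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention(Q, K, V, block_size=32), ref)
```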