
Commit d5768ba

LoRA typo fix
1 parent 9dd97ff commit d5768ba

2 files changed: +14 −17 lines changed


docs/lora/index.html

Lines changed: 14 additions & 16 deletions
(Generated documentation for the LoRA embedding layer. The hunks below drop the stray `nn.Embedding` line from the rendered code and shift the displayed source line numbers down by one.)

@@ -361,9 +361,7 @@ <h2>LoRA Embedding Layer</h2>
-    124          alpha = r
-    125
-    126      nn.Embedding
+    124          alpha = r

@@ -375,8 +373,8 @@ <h2>LoRA Embedding Layer</h2>
-    128      self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
-    129      self.weight.requires_grad = False
+    127      self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
+    128      self.weight.requires_grad = False

@@ -388,7 +386,7 @@ <h2>LoRA Embedding Layer</h2>
-    132      self.scaling = alpha / r
+    131      self.scaling = alpha / r

@@ -400,7 +398,7 @@ <h2>LoRA Embedding Layer</h2>
-    134      self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
+    133      self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))

@@ -412,9 +410,9 @@ <h2>LoRA Embedding Layer</h2>
-    136      self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))
-    137
-    138      with torch.no_grad():
+    135      self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))
+    136
+    137      with torch.no_grad():

@@ -426,7 +424,7 @@ <h2>LoRA Embedding Layer</h2>
-    140          nn.init.normal_(self.lora_a)
+    139          nn.init.normal_(self.lora_a)

@@ -438,7 +436,7 @@ <h2>LoRA Embedding Layer</h2>
-    142          nn.init.zeros_(self.lora_b)
+    141          nn.init.zeros_(self.lora_b)

@@ -449,7 +447,7 @@ <h2>LoRA Embedding Layer</h2>
-    144  def forward(self, x: torch.Tensor):
+    143  def forward(self, x: torch.Tensor):

@@ -461,7 +459,7 @@ <h2>LoRA Embedding Layer</h2>
-    146      result = nn.functional.embedding(x, self.weight)
+    145      result = nn.functional.embedding(x, self.weight)

@@ -473,7 +471,7 @@ <h2>LoRA Embedding Layer</h2>
-    149      result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling
+    148      result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

@@ -485,7 +483,7 @@ <h2>LoRA Embedding Layer</h2>
-    152      return result
+    151      return result
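Taken together, these documentation hunks cover the whole LoRA embedding layer: a frozen weight matrix $W_0$, low-rank factors $A$ (`lora_a`) and $B$ (`lora_b`) scaled by $\frac{\alpha}{r}$, and a forward pass that adds the scaled low-rank update to the frozen lookup. Assembled from the fragments above, the layer after this fix looks roughly like the following sketch; the class name, the `r`/`alpha` arguments in the signature, and the comments are inferred from the diff rather than copied verbatim, so the file in the repository may differ in detail.

import torch
import torch.nn as nn


class Embedding(nn.Module):
    # Sketch of the LoRA embedding layer assembled from the diff fragments above.
    # The class name and the `r`/`alpha` arguments are assumptions; the actual
    # class in labml_nn/lora/__init__.py may carry additional documentation.
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 r: int, alpha: int = None):
        super().__init__()

        if alpha is None:
            alpha = r

        # The pre-trained embedding weights $W_0^T$ (frozen)
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        # Scaling factor alpha / r applied to the low-rank update
        self.scaling = alpha / r

        # Low-rank factors: A maps token ids into an r-dimensional space,
        # B projects that back to the embedding dimension
        self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
        self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

        with torch.no_grad():
            # A starts random and B starts at zero, so the initial update is zero
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen embedding lookup
        result = nn.functional.embedding(x, self.weight)

        # Scaled low-rank update: look x up in A^T, then project with B^T
        result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

        return result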

labml_nn/lora/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -123,7 +123,6 @@ def __init__(self, num_embeddings: int, embedding_dim: int,
 123  123          if alpha is None:
 124  124              alpha = r
 125  125
 126       -        nn.Embedding
 127  126          # The pre-trained embedding weights $W_0^T$ (frozen)
 128  127          self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
 129  128          self.weight.requires_grad = False
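A minimal usage sketch, assuming the class is importable as `labml_nn.lora.Embedding` and instantiated with hypothetical sizes; in practice the frozen `weight` would be copied from a pre-trained embedding matrix before use.

import torch

from labml_nn.lora import Embedding  # import path and class name assumed from this repository's layout

# Hypothetical sizes for illustration
emb = Embedding(num_embeddings=50_257, embedding_dim=768, r=8, alpha=16)

# Normally the frozen weights come from a pre-trained model, e.g.
#   with torch.no_grad():
#       emb.weight.copy_(pretrained_embedding_weight)
# Here the tensor is left uninitialized, so the output values are meaningless.

x = torch.randint(0, 50_257, (2, 16))  # a batch of 2 sequences of 16 token ids
out = emb(x)
print(out.shape)  # torch.Size([2, 16, 768])

# Only the low-rank factors are trainable; the base embedding stays frozen
trainable = [name for name, p in emb.named_parameters() if p.requires_grad]
print(trainable)  # ['lora_a', 'lora_b']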

0 commit comments