-
The backward pass is actually more expensive as well. You still need to backprop through the frozen weights to get gradients for the LoRA parameters in earlier layers. If all you want to do is apply LoRA to GPT-2, I have an example doing that with my LoRA implementation, Lorax, here. That being said, if you're using an A100x8 node it's probably just best to go with full finetuning.
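In other words (a rough sketch of the standard accounting, not something spelled out in the original reply): for a frozen layer $Y = XW$ with $X \in \mathbb{R}^{n \times d_{\text{in}}}$ and $W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$, the backward pass still has to form

$$
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\,W^{\top}
$$

to push gradients to earlier layers and their LoRA factors, and that product costs the same whether or not $W$ is trainable. Freezing $W$ only skips $\partial L/\partial W = X^{\top}\,\partial L/\partial Y$, while the adapter path $(XU)V$ adds its own forward and backward work on top.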
-
I want to apply LoRA to GPT-2, so I created a counterpart of nanoGPT in JAX. Here is the repo: smolGPT. However, it turns out that each LoRA finetuning step takes ~50% longer than a full finetuning step on an A100x8 node, which completely defeats its purpose!
To run the repo, simply do `make .venv && make train`. As a prerequisite, you need to prepare `data/openwebtext/train.bin` and `data/openwebtext/val.bin` according to the instructions in nanoGPT and copy them over. The default parameters in `train.py` do finetuning with a rank-one update; you can set `lora_rank: Optional[int] = None` to do full finetuning.
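For illustration, the switch might look roughly like this (hypothetical field names other than `lora_rank`; the actual configuration in `train.py` may be organized differently):

```python
# Hypothetical sketch of the lora_rank switch, not the repo's actual config.
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:
    lora_rank: Optional[int] = 1    # default: finetune with a rank-1 LoRA update
    learning_rate: float = 3e-4     # made-up value for illustration

lora_config = TrainConfig()                 # LoRA finetuning
full_config = TrainConfig(lora_rank=None)   # lora_rank=None -> full finetuning
```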
The training logic is in `train.py` and the GPT-2 model is defined in `smolGPT/model.py`. Here is a quick overview of my implementation of LoRA. Let's say the original parameters look like `{"linear": {"w": *, "b": *}}`. We first add an empty `uv` key by passing them through `inject_uv`. When LoRA is enabled, `init_lora` initializes the low-rank parameters, giving `{"linear": {"w": None, "b": None, "uv": (*, *)}}`. Eventually, `frozen` has structure `{"linear": {"w": *, "b": *, "uv": None}}`, whereas `params` has structure `{"linear": {"w": None, "b": None, "uv": (*, *)}}`.
During the training step, we merge `frozen` and `params` together and pass them into the model. However, we only calculate the gradient with respect to `p`, i.e. `params`. The model then applies the LoRA update whenever `uv` is present.
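Here is how that step looks as a self-contained sketch (hypothetical names, a toy single-layer loss rather than the real GPT-2 forward pass, and `merge` reused from the sketch above):

```python
# Sketch of the training step: only `params` is differentiated; `frozen` is a
# plain argument that participates in the computation but gets no gradient.
import jax
import jax.numpy as jnp

def lora_linear(layer, x):
    """y = x @ w + b, plus the low-rank update whenever `uv` is present."""
    y = x @ layer["w"] + layer["b"]
    if layer["uv"] is not None:     # structural check, resolved at trace time
        u, v = layer["uv"]
        y = y + (x @ u) @ v
    return y

def loss_fn(params, frozen, batch):
    merged = merge(frozen, params)                    # merge() from the earlier sketch
    preds = lora_linear(merged["linear"], batch["x"])
    return jnp.mean((preds - batch["y"]) ** 2)        # toy loss; the real model uses cross-entropy

@jax.jit
def train_step(params, frozen, batch, lr=3e-4):
    loss, grads = jax.value_and_grad(loss_fn)(params, frozen, batch)
    # `None` entries in `params` are empty pytree nodes, so only (u, v) are updated.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

The point is that `frozen` is an ordinary argument of `loss_fn`: it takes part in the forward and backward computation, it just isn't among the differentiated arguments (`argnums` defaults to 0, i.e. `params`).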
Why is LoRA slower? I understand that the extra `+ (x @ u) @ v` term requires additional computation during the forward pass, but the backward pass generates much less gradient information, so I would expect a speedup. It would be great if you could take a look at my code and see whether I did anything inefficient. Other general suggestions are also appreciated!

Bonus question: each full finetuning step takes 20-30% longer than a step in nanoGPT (its counterpart in PyTorch). Can you spot the reason? I attempted to do some profiling, but the trace contains very little information about GPU usage, so it didn't really help me understand the situation.