
Commit 975bd57

Add QLoRA and FP8 to finetuning tutorial (part 2) (#2542)
This is part 2 of the end-to-end tutorial. Previously we already had QAT. This commit also adds QLoRA and FP8. To preview, visit https://docs-preview.pytorch.org/pytorch/ao/2542/finetuning.html
1 parent c011bad commit 975bd57

1 file changed: +115 −2 lines

docs/source/finetuning.rst

Lines changed: 115 additions & 2 deletions
@@ -284,10 +284,123 @@ schemes, but these are not customizable unlike the above example.

Quantized Low-Rank Adaptation (QLoRA)
#####################################

Low-Rank Adaptation (LoRA) refers to freezing the original model
and instead training a set of new "adapter" parameters that are a
small fraction of the original parameters, thereby significantly
reducing the memory footprint during training. QLoRA is an extension
of LoRA that additionally quantizes the frozen original model
parameters to 4 bits, further reducing the memory footprint.
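
To make the idea concrete, here is a minimal sketch of a LoRA-style
linear layer. This is not TorchAO's or TorchTune's implementation;
the class name and the `rank` and `alpha` defaults are illustrative only.

.. code::

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  class LoRALinearSketch(nn.Module):
      """Frozen base weight plus trainable low-rank adapters (illustrative sketch)."""

      def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
          super().__init__()
          # Frozen base weight; in QLoRA this tensor would additionally be quantized (e.g. to NF4)
          self.weight = nn.Parameter(torch.randn(out_dim, in_dim) * 0.02, requires_grad=False)
          # Trainable adapters add only rank * (in_dim + out_dim) new parameters
          self.lora_a = nn.Parameter(torch.randn(rank, in_dim) * 0.01)
          self.lora_b = nn.Parameter(torch.zeros(out_dim, rank))
          self.scaling = alpha / rank

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # y = x @ W^T + scaling * ((x @ A^T) @ B^T)
          base = F.linear(x, self.weight)
          adapter = F.linear(F.linear(x, self.lora_a), self.lora_b)
          return base + self.scaling * adapter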

TorchAO offers an implementation of the NF4 data type proposed in
the original `QLoRA paper <https://arxiv.org/pdf/2305.14314>`__.
This implementation expresses NF4 as a tensor subclass through the
`NF4Tensor <https://docs.pytorch.org/ao/stable/generated/torchao.dtypes.NF4Tensor.html>`__,
which composes cleanly with other PyTorch features like `torch.compile`
and FSDP2. Users can convert a high-precision tensor to NF4 simply
by calling `torchao.dtypes.to_nf4 <https://docs.pytorch.org/ao/stable/generated/torchao.dtypes.to_nf4.html>`__.
For example:

.. code::

  from typing import Optional

  import torch
  import torch.nn as nn
  from torchao.dtypes import to_nf4

  class FrozenNF4Linear(nn.Linear):
      def __init__(
          self,
          in_dim: int,
          out_dim: int,
          bias: bool = False,
          device: Optional[torch.device] = None,
          dtype: Optional[torch.dtype] = None,
          **quantization_kwargs,
      ):
          super().__init__(in_dim, out_dim, bias=bias, device=device, dtype=dtype)
          # No need to train these in QLoRA
          self.weight.requires_grad_(False)
          if self.bias is not None:
              self.bias.requires_grad_(False)
          # Replace the frozen high-precision weight with its NF4-quantized version
          nf4_weight = to_nf4(self.weight, **quantization_kwargs)
          self.weight = torch.nn.Parameter(nf4_weight, requires_grad=False)
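
As a quick illustration of the conversion itself (the tensor size is
arbitrary and this is only a sketch):

.. code::

  import torch
  from torchao.dtypes import to_nf4

  # Quantize a high-precision weight tensor to NF4
  weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
  nf4_weight = to_nf4(weight)

  # The result is an NF4Tensor subclass that still reports the original shape
  print(type(nf4_weight), nf4_weight.shape)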

QLoRA is not tied to NF4 specifically, though NF4 has been
shown to achieve competitive results compared to bf16 baselines
while significantly reducing the memory required for training.
This technique can also compose with other lower-bit dtypes
such as regular INT4 or even newer `MXFP4 or NVFP4 <https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats>`__
targeting Blackwell GPUs to reap similar memory benefits with
varying tradeoffs.

Option 1: TorchTune Integration
===============================

TorchTune incorporates the `NF4Tensor` in its QLoRA fine-tuning
recipe through its implementation of `LoRALinear <https://github.com/pytorch/torchtune/blob/a6290a5b40758f13bca61c386bc8756a49ef417e/torchtune/modules/peft/lora.py#L19>`__.
You can try it out by running the following command,
or refer to the `QLoRA tutorial <https://docs.pytorch.org/torchtune/stable/tutorials/qlora_finetune.html>`__
for more details.

.. code::

  tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device.yaml

Option 2: HuggingFace PEFT Integration
======================================

`HuggingFace PEFT <https://huggingface.co/docs/peft/main/en/developer_guides/quantization#torchao-pytorch-architecture-optimization>`__
also has a limited version of QLoRA leveraging TorchAO's INT8
quantization, though INT4 and NF4 are not supported yet. Users
can invoke this functionality by preparing their models as follows.
For full details, please refer to `this tutorial <https://huggingface.co/docs/peft/main/en/developer_guides/quantization#torchao-pytorch-architecture-optimization>`__.

.. code::

  from peft import LoraConfig, get_peft_model
  from transformers import AutoModelForCausalLM, TorchAoConfig
  from torchao.quantization import Int8WeightOnlyConfig

  base_model = AutoModelForCausalLM.from_pretrained(
      "meta-llama/Llama-3.2-1B",
      quantization_config=TorchAoConfig(Int8WeightOnlyConfig()),
  )
  peft_config = LoraConfig()
  model = get_peft_model(base_model, peft_config)
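
After wrapping the model, only the LoRA adapter parameters remain
trainable. A quick sanity check using PEFT's standard helper might look
like the following; fine-tuning then proceeds with your usual training
loop or `transformers.Trainer`:

.. code::

  # Should report a small trainable fraction relative to the total parameter count
  model.print_trainable_parameters()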


Float8 Quantized Fine-tuning
############################

Similar to `pre-training <pretraining.html>`__, we can also
leverage float8 in fine-tuning for higher training throughput
with no accuracy degradation and no increase in memory usage.
Float8 training is integrated into TorchTune's distributed
full fine-tuning recipe, leveraging the same APIs as our
integration with TorchTitan. Users can invoke this fine-tuning
recipe as follows:

.. code::

  tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full \
    enable_fp8_training=true \
    fp8_recipe_name=tensorwise \
    compile=True
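
Under the hood, this recipe relies on TorchAO's float8 training
conversion API. A rough sketch of applying it directly to a model,
outside of TorchTune, is shown below; the toy model, the filter rule,
and the hardware assumption (a GPU with float8 support) are
illustrative only.

.. code::

  import torch
  import torch.nn as nn
  from torchao.float8 import convert_to_float8_training

  model = nn.Sequential(
      nn.Linear(4096, 4096, bias=False),
      nn.Linear(4096, 4096, bias=False),
  ).to(torch.bfloat16).cuda()

  # Only convert linears whose shapes are friendly to float8 kernels (illustrative rule)
  def module_filter_fn(module: nn.Module, fqn: str) -> bool:
      return isinstance(module, nn.Linear) and module.in_features % 16 == 0

  # Swap eligible nn.Linear modules for float8 training variants, then train as usual
  convert_to_float8_training(model, module_filter_fn=module_filter_fn)
  model = torch.compile(model)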

Initial experiments saw up to 16.5% throughput improvement
for fine-tuning Llama3.2-3B in float8:

.. code::

  experiment_name           tok/s                 peak_mem_reserved
  ----------------------    -------------------   -------------------
  bf16                      6502.143 (+0.000%)    30.090 (+0.000%)
  fp8_noname                7205.386 (+10.816%)   30.010 (-0.266%)
  fp8_tensorwise            7222.198 (+11.074%)   30.010 (-0.266%)
  fp8_rowwise               6387.968 (-1.756%)    29.158 (-3.096%)
  fp8_rowwise_with_gw_hp    7573.698 (+16.480%)   29.516 (-1.908%)

  experiment_name           hellaswag_acc     wikitext_word_perplexity
  ----------------------    ---------------   --------------------------
  bf16                      0.533 (+0.000)    12.407 (+0.000)
  fp8_noname                0.533 (+0.000)    12.414 (+0.007)
  fp8_tensorwise            0.533 (+0.000)    12.412 (+0.005)
  fp8_rowwise               0.533 (-0.000)    12.420 (+0.013)
  fp8_rowwise_with_gw_hp    0.534 (+0.001)    12.416 (+0.009)

Please refer to the `pre-training <pretraining.html>`__ tutorial for more details.
