
Commit fee1b2d

Quantize lora linears (#15935)
### Summary

LoRALinears contain:

1. base weight (nn.Linear)
2. lora_a (nn.Linear)
3. lora_b (nn.Linear)

(2) and (3) are caught by the quantization filter, but (1) is not, because the weight and bias are pulled out of the nn.Linear and placed into nn.Parameters, and the linear is performed manually. This is for checkpoint compatibility; otherwise we'd have to remap the weights for every LoRA model. See:
https://github.com/pytorch/executorch/blob/b4d72f1e271915e9c0e1d313753a1eec840fbdee/examples/models/llama/lora.py#L31-L37

This PR adds LoRA linears to the quantization filter. A minimal sketch of this structure is shown after the test plan below.

### Test plan

```
python -m extension.llm.export.export_llm \
    base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
    base.params="${DOWNLOADED_PATH}/params.json" \
    base.adapter_checkpoint="../et_docs_7_epoch/adapter_model.safetensors" \
    base.adapter_config="../et_docs_7_epoch/adapter_config.json" \
    base.tokenizer_path="../et_docs_7_epoch/" \
    model.use_kv_cache=true \
    model.use_sdpa_with_kv_cache=true \
```

Confirm the output model size is ~1.7GB instead of 5.1GB:

```
(executorch) [[email protected] /data/users/lfq/executorch (lfq.quantize-lora-linears)]$ ls -la *.pte
-rw-r--r-- 1 lfq users 5106135168 Nov 20 15:59 et_lora.pte
-rw-r--r-- 1 lfq users 1733835776 Nov 20 17:07 et_lora_fix.pte
```
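For illustration, here is a minimal sketch of the structure described in the summary (a hypothetical, simplified stand-in for the real module in `examples/models/llama/lora.py`, not the actual implementation). Because the base weight lives in an `nn.Parameter` on the `LoRALinear` itself rather than in an `nn.Linear` child, a filter that only checks `isinstance(m, nn.Linear)` sees `lora_a` and `lora_b` but never the module holding the base weight:

```python
import torch
from torch import nn
from torch.nn import functional as F


class LoRALinear(nn.Module):
    """Simplified, hypothetical stand-in for the real LoRALinear:
    the base weight is a plain nn.Parameter, not an nn.Linear child."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8):
        super().__init__()
        # Base weight is pulled out of nn.Linear for checkpoint compatibility.
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The base linear is performed manually on the raw parameter.
        return F.linear(x, self.weight) + self.lora_b(self.lora_a(x))


layer = LoRALinear(16, 16)
for fqn, m in layer.named_modules():
    # Only lora_a and lora_b pass an isinstance(m, nn.Linear) check;
    # the root module, which owns the base weight, does not.
    print(fqn or "<root>", type(m).__name__, isinstance(m, nn.Linear))
```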
1 parent a4298ac commit fee1b2d


examples/models/llama/source_transformation/quantize.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -159,13 +159,27 @@ def quantize(  # noqa C901
     from torchao.utils import unwrap_tensor_subclass
 
     def filter_fn(m, fqn):
+        # Check if it's a regular nn.Linear
         is_linear = isinstance(m, nn.Linear)
+
+        # Check if it's a LoRALinear (which has a base weight parameter to quantize)
+        is_lora_linear = False
+        try:
+            from executorch.examples.models.llama.lora import LoRALinear
+
+            is_lora_linear = isinstance(m, LoRALinear)
+        except ImportError:
+            pass
+
+        # Check if the weight shape is compatible with group size
         has_shape_compatible_with_group_size = False
-        if is_linear:
+        if is_linear or is_lora_linear:
             has_shape_compatible_with_group_size = (
                 m.weight.shape[1] % group_size == 0
             )
-        return is_linear and has_shape_compatible_with_group_size
+        return (
+            is_linear or is_lora_linear
+        ) and has_shape_compatible_with_group_size
 
     quantize_(
         model,
```
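As a usage sketch (assumptions: torchao's `quantize_` and `int8_weight_only` APIs and a toy model; the export flow's actual quantization config may differ), the filter is what decides which modules `quantize_` touches, which is why widening it to accept `LoRALinear` shrinks the exported model:

```python
import torch
from torch import nn
from torchao.quantization import int8_weight_only, quantize_

group_size = 32


def filter_fn(m, fqn):
    # Same shape check as the patched filter above; the LoRALinear
    # branch is omitted since this toy model has none.
    is_linear = isinstance(m, nn.Linear)
    return is_linear and m.weight.shape[1] % group_size == 0


model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 10))
quantize_(model, int8_weight_only(), filter_fn=filter_fn)
print(model)  # matching linear weights are now quantized tensor subclasses
```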
