Skip to content

Commit 301d0b2

Browse files
committed
model layer stuff
1 parent fae3f5b commit 301d0b2

File tree

4 files changed

+131
-1
lines changed

4 files changed

+131
-1
lines changed

example_trainer/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,48 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
570570
| `--lora-alpha` | 32 | LoRA alpha scaling factor |
571571
| `--lora-dropout` | 0.05 | LoRA dropout probability |
572572
| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
573+
| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) |
574+
575+
### LoRA Layer Index Guide (by Architecture)
576+
577+
`--lora-layer-indices` is model-dependent. Indices are zero-based and must be smaller than the model's number of transformer blocks; because different models expose different numbers of blocks, a range that is valid for one model may be rejected for another.
578+
579+
| Architecture family | Common config fields | Typical layer list path | Notes |
580+
|---------------------|----------------------|-------------------------|-------|
581+
| LLaMA / Llama-2 / Llama-3 / Mistral | `num_hidden_layers` | `model.layers` | Most common causal-LM layout |
582+
| Qwen / Qwen2 / Qwen2.5 / Qwen3 | `num_hidden_layers` | `model.layers` | Similar layer indexing to LLaMA |
583+
| GPT-2 / GPT-J style | `n_layer` (may also be exposed as `num_hidden_layers`) | `transformer.h` | PEFT may use the `h` pattern internally |
584+
| Falcon | `num_hidden_layers` | `transformer.h` | Uses `h` block list in model module tree |
585+
586+
#### Reliable way to check for any model
587+
588+
Always query the model config before choosing indices:
589+
590+
```bash
591+
python - <<'PY'
592+
from transformers import AutoConfig
593+
594+
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
595+
cfg = AutoConfig.from_pretrained(model_id)
596+
num_layers = getattr(cfg, "num_hidden_layers", None)
597+
if num_layers is None:
598+
num_layers = getattr(cfg, "n_layer", None)
599+
600+
print(f"model={model_id}")
601+
print(f"num_hidden_layers={num_layers}")
602+
if num_layers is not None:
603+
print(f"valid index range: 0-{num_layers-1}")
604+
PY
605+
```
606+
607+
#### Practical presets
608+
609+
If your model has `N` layers:
610+
611+
- All layers: omit `--lora-layer-indices`
612+
- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}`
613+
- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}`
614+
- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`)
573615

574616
### vLLM Arguments
575617

example_trainer/cli.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import argparse
9+
from typing import List, Optional
910

1011
import torch
1112

@@ -16,6 +17,53 @@
1617
# =============================================================================
1718

1819

20+
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
    """
    Parse LoRA layer indices from comma/range syntax.

    Supported formats:
    - "20-31"
    - "0,1,2,28,29,30,31"
    - "0-3,28-31"

    Ranges are inclusive on both ends. Whitespace around items and around
    the range separator is ignored, and empty items (e.g. a trailing comma)
    are skipped.

    Args:
        value: Raw command-line string, or None.

    Returns:
        A sorted list of unique, non-negative layer indices, or None when
        ``value`` is None or contains no items.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer or an
            integer range, if a range has start > end, or if any index is
            negative.
    """
    if value is None:
        return None

    parts = [part.strip() for part in value.split(",") if part.strip()]
    if not parts:
        return None

    negative_msg = f"Invalid --lora-layer-indices '{value}': indices must be >= 0"

    # A set removes duplicates produced by overlapping ranges/items.
    indices = set()
    for part in parts:
        # A leading '-' can only be a minus sign, never a range separator.
        # Reject it here so the user gets the "must be >= 0" message rather
        # than an opaque int('') parsing error.
        if part.startswith("-"):
            raise argparse.ArgumentTypeError(negative_msg)

        # Split on the first '-' only; anything after it must parse as one
        # integer (so "1-2-3" is rejected by int() below).
        start_s, sep, end_s = part.partition("-")
        try:
            start = int(start_s.strip())
            end = int(end_s.strip()) if sep else start
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Invalid --lora-layer-indices value '{value}': {e}"
            ) from e

        # Only the range end can still be negative here (e.g. "1--3").
        if start < 0 or end < 0:
            raise argparse.ArgumentTypeError(negative_msg)
        if start > end:
            raise argparse.ArgumentTypeError(
                f"Invalid range '{part}': start must be <= end"
            )

        indices.update(range(start, end + 1))

    return sorted(indices)
65+
66+
1967
def add_model_args(parser: argparse.ArgumentParser) -> None:
2068
"""Add model-related arguments."""
2169
group = parser.add_argument_group("Model")
@@ -225,6 +273,15 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
225273
default=None,
226274
help="Module names to apply LoRA to (default: q_proj v_proj)",
227275
)
276+
group.add_argument(
277+
"--lora-layer-indices",
278+
type=_parse_lora_layer_indices,
279+
default=None,
280+
help=(
281+
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
282+
"'0-3,28-31'. If omitted, applies to all matching layers."
283+
),
284+
)
228285

229286

230287
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
@@ -373,6 +430,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
373430
lora_alpha=getattr(args, "lora_alpha", 32),
374431
lora_dropout=getattr(args, "lora_dropout", 0.05),
375432
lora_target_modules=getattr(args, "lora_target_modules", None),
433+
lora_layer_indices=getattr(args, "lora_layer_indices", None),
376434
vllm_config_path=getattr(args, "vllm_config_path", None),
377435
debug_loading=getattr(args, "debug_loading", False),
378436
benchmark=getattr(args, "benchmark", False),

example_trainer/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ class TrainingConfig(BaseModel):
154154
"If None, defaults to ['q_proj', 'v_proj'] for most models."
155155
),
156156
)
157+
lora_layer_indices: Optional[List[int]] = Field(
158+
None,
159+
description=(
160+
"Optional list of transformer layer indices to apply LoRA to. "
161+
"If None, applies LoRA to all matching layers."
162+
),
163+
)
157164

158165
# === Single-Copy Mode Configuration ===
159166
single_copy: bool = Field(

example_trainer/model.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,41 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
200200
target_modules = config.lora_target_modules
201201
if target_modules is None:
202202
target_modules = ["q_proj", "v_proj"]
203+
layer_indices = config.lora_layer_indices
204+
205+
if layer_indices is not None:
206+
num_hidden_layers = getattr(base_model.config, "num_hidden_layers", None)
207+
if num_hidden_layers is None:
208+
raise RuntimeError(
209+
"Model config does not expose num_hidden_layers; cannot validate "
210+
"--lora-layer-indices for this architecture."
211+
)
212+
invalid = [idx for idx in layer_indices if idx >= num_hidden_layers]
213+
if invalid:
214+
raise ValueError(
215+
f"Invalid --lora-layer-indices {invalid} for model with "
216+
f"{num_hidden_layers} layers (valid range: 0-{num_hidden_layers - 1})"
217+
)
203218

204219
print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
205220
print(f"Target modules: {target_modules}")
221+
if layer_indices is not None:
222+
print(
223+
f"Applying LoRA only to layers: {layer_indices} "
224+
f"(total {len(layer_indices)})"
225+
)
206226

207-
lora_config = LoraConfig(
227+
lora_kwargs = dict(
208228
task_type=TaskType.CAUSAL_LM,
209229
r=config.lora_r,
210230
lora_alpha=config.lora_alpha,
211231
lora_dropout=config.lora_dropout,
212232
target_modules=target_modules,
213233
bias="none",
214234
)
235+
if layer_indices is not None:
236+
lora_kwargs["layers_to_transform"] = layer_indices
237+
lora_config = LoraConfig(**lora_kwargs)
215238

216239
model = get_peft_model(base_model, lora_config)
217240
model.print_trainable_parameters()

0 commit comments

Comments
 (0)