|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import argparse |
| 9 | +from typing import List, Optional |
9 | 10 |
|
10 | 11 | import torch |
11 | 12 |
|
|
16 | 17 | # ============================================================================= |
17 | 18 |
|
18 | 19 |
|
| 20 | +def _parse_lora_layer_indices(value: str) -> Optional[List[int]]: |
| 21 | + """ |
| 22 | + Parse LoRA layer indices from comma/range syntax. |
| 23 | +
|
| 24 | + Supported formats: |
| 25 | + - "20-31" |
| 26 | + - "0,1,2,28,29,30,31" |
| 27 | + - "0-3,28-31" |
| 28 | + """ |
| 29 | + if value is None: |
| 30 | + return None |
| 31 | + |
| 32 | + raw = value.strip() |
| 33 | + if not raw: |
| 34 | + return None |
| 35 | + |
| 36 | + indices: List[int] = [] |
| 37 | + parts = [part.strip() for part in raw.split(",") if part.strip()] |
| 38 | + |
| 39 | + try: |
| 40 | + for part in parts: |
| 41 | + if "-" in part: |
| 42 | + start_s, end_s = part.split("-", 1) |
| 43 | + start = int(start_s.strip()) |
| 44 | + end = int(end_s.strip()) |
| 45 | + if start > end: |
| 46 | + raise argparse.ArgumentTypeError( |
| 47 | + f"Invalid range '{part}': start must be <= end" |
| 48 | + ) |
| 49 | + indices.extend(range(start, end + 1)) |
| 50 | + else: |
| 51 | + indices.append(int(part)) |
| 52 | + except ValueError as e: |
| 53 | + raise argparse.ArgumentTypeError( |
| 54 | + f"Invalid --lora-layer-indices value '{value}': {e}" |
| 55 | + ) from e |
| 56 | + |
| 57 | + if not indices: |
| 58 | + return None |
| 59 | + if any(idx < 0 for idx in indices): |
| 60 | + raise argparse.ArgumentTypeError( |
| 61 | + f"Invalid --lora-layer-indices '{value}': indices must be >= 0" |
| 62 | + ) |
| 63 | + |
| 64 | + return sorted(set(indices)) |
| 65 | + |
| 66 | + |
19 | 67 | def add_model_args(parser: argparse.ArgumentParser) -> None: |
20 | 68 | """Add model-related arguments.""" |
21 | 69 | group = parser.add_argument_group("Model") |
@@ -225,6 +273,15 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None: |
225 | 273 | default=None, |
226 | 274 | help="Module names to apply LoRA to (default: q_proj v_proj)", |
227 | 275 | ) |
| 276 | + group.add_argument( |
| 277 | + "--lora-layer-indices", |
| 278 | + type=_parse_lora_layer_indices, |
| 279 | + default=None, |
| 280 | + help=( |
| 281 | + "Optional layer indices to apply LoRA to, e.g. '20-31' or " |
| 282 | + "'0-3,28-31'. If omitted, applies to all matching layers." |
| 283 | + ), |
| 284 | + ) |
228 | 285 |
|
229 | 286 |
|
230 | 287 | def add_distributed_args(parser: argparse.ArgumentParser) -> None: |
@@ -373,6 +430,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: |
373 | 430 | lora_alpha=getattr(args, "lora_alpha", 32), |
374 | 431 | lora_dropout=getattr(args, "lora_dropout", 0.05), |
375 | 432 | lora_target_modules=getattr(args, "lora_target_modules", None), |
| 433 | + lora_layer_indices=getattr(args, "lora_layer_indices", None), |
376 | 434 | vllm_config_path=getattr(args, "vllm_config_path", None), |
377 | 435 | debug_loading=getattr(args, "debug_loading", False), |
378 | 436 | benchmark=getattr(args, "benchmark", False), |
|
0 commit comments