Skip to content

Commit d7d0f8b

Browse files
committed
Use save_weights_for_sampler for eval and add LoRA rank to client creation
1 parent 269008b commit d7d0f8b

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

config_schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class TrainingConfig(BaseModel):
6161
max_seq_length: int = Field(
6262
default=2048, ge=128, le=32768, description="Maximum sequence length"
6363
)
64+
lora_rank: int = Field(
65+
default=16, ge=1, le=256, description="LoRA rank (adapter dimension)"
66+
)
6467

6568
@field_validator("train_file")
6669
@classmethod

trainer_with_eval.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,11 @@ async def async_main(config_path: str) -> None:
208208
except ValueError as e:
209209
print(f"Warning: Could not initialize EvalOps client: {e}")
210210

211-
print(f"Creating LoRA training client for {base_model}...")
212-
training_client = service_client.create_lora_training_client(base_model=base_model)
211+
print(f"Creating LoRA training client for {base_model} (rank={config.lora_rank})...")
212+
training_client = service_client.create_lora_training_client(
213+
base_model=base_model,
214+
rank=config.lora_rank,
215+
)
213216

214217
tokenizer = training_client.get_tokenizer()
215218

0 commit comments

Comments (0)