Skip to content

Commit b72a1fc

Browse files
committed
Apply black formatting
1 parent d70cb83 commit b72a1fc

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

scripts/sample_data.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def parse_args() -> argparse.Namespace:
1111
Parse command-line arguments.
1212
1313
Returns:
14-
argparse.Namespace:
15-
- input (str): Path to the full CSV file to sample from.
16-
- output (str): Path where the sampled CSV will be written.
17-
- n (int): Number of examples to sample.
14+
argparse.Namespace:
15+
- input (str): Path to the full CSV file to sample from.
16+
- output (str): Path where the sampled CSV will be written.
17+
- n (int): Number of examples to sample.
1818
- seed (int): Random seed for reproducibility.
1919
"""
2020
p = argparse.ArgumentParser(
@@ -41,11 +41,7 @@ def parse_args() -> argparse.Namespace:
4141
return p.parse_args()
4242

4343

44-
def sample_dataset(
45-
input_csv: Path,
46-
output_csv: Path,
47-
sample_size: int,
48-
seed: int = 42):
44+
def sample_dataset(input_csv: Path, output_csv: Path, sample_size: int, seed: int = 42):
4945
"""
5046
Load a CSV, draw a random sample, and write it out.
5147

src/bart_reddit_lora/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@ def build_peft_model(
2828
r: int = 8,
2929
lora_dropout: float = 0.1,
3030
bias: str = "none",
31-
target_modules: Sequence[str] = ("q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"),
31+
target_modules: Sequence[str] = (
32+
"q_proj",
33+
"k_proj",
34+
"v_proj",
35+
"o_proj",
36+
"fc1",
37+
"fc2",
38+
),
3239
modules_to_save: Sequence[str] = ("lm_head",),
3340
) -> PeftModel:
3441
"""Wrap a base BART model with a LoRA adapter configuration.

0 commit comments

Comments
 (0)