Skip to content

Commit 3ee0d56

Browse files
author
Ubuntu
committed
added docstrings and annotations
1 parent 04b2608 commit 3ee0d56

File tree

10 files changed

+306
-29
lines changed

10 files changed

+306
-29
lines changed

app.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1+
"""
2+
Module for loading a LoRA fine-tuned BART model and serving
3+
an interactive Gradio interface for text generation.
4+
"""
5+
16
import torch
27
import gradio as gr
38
from transformers import AutoTokenizer
49
from transformers import BartForConditionalGeneration
510
from peft import PeftModel
611

712

8-
def load_model():
13+
def load_model() -> tuple[AutoTokenizer, PeftModel, torch.device]:
914
"""
10-
Load environment variables, tokenizer, and the fine-tuned LoRA model.
15+
Load tokenizer and LoRA-enhanced model onto available device.
16+
17+
Returns:
18+
tokenizer (AutoTokenizer): Tokenizer for text processing.
19+
model (PeftModel): Fine-tuned LoRA BART model in eval mode.
20+
device (torch.device): Computation device (GPU if available, else CPU).
1121
"""
1222
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1323

@@ -33,7 +43,13 @@ def load_model():
3343

3444
def predict(text: str) -> str:
3545
"""
36-
Generate a response for a single input text.
46+
Generate a text response given an input prompt.
47+
48+
Args:
49+
text (str): The input prompt string.
50+
51+
Returns:
52+
str: The decoded model output.
3753
"""
3854
# Tokenize and move inputs to device
3955
inputs = tokenizer(
@@ -62,7 +78,10 @@ def predict(text: str) -> str:
6278
return tokenizer.decode(outputs[0], skip_special_tokens=True)
6379

6480

65-
def main():
81+
def main() -> None:
82+
"""
83+
Launch Gradio web interface for interactive model inference.
84+
"""
6685
interface = gr.Interface(
6786
fn=predict,
6887
inputs=gr.Textbox(lines=5, placeholder="Ask a Question", label="Your Question"),

scripts/download_data.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
1+
"""
2+
Download Reddit Q&A posts, preprocess them, and split into train/validation/test sets.
3+
"""
4+
15
import json
26
import pandas as pd
37
import argparse
48
from pathlib import Path
59
from bart_reddit_lora.data import scrape, preprocess, split_and_save
610

711

8-
def parse_args():
9-
"""Parse command-line arguments for downloading and preprocessing data."""
12+
def parse_args() -> argparse.Namespace:
13+
"""
14+
Parse command-line arguments for data scraping and preprocessing.
15+
16+
Returns:
17+
argparse.Namespace: Parsed arguments with attributes:
18+
config (str): Path to subreddit size map JSON.
19+
raw_dir (str): Directory to save raw data.
20+
out_dir (str): Directory to save processed data.
21+
"""
1022
p = argparse.ArgumentParser(prog="download-data")
1123

1224
p.add_argument("--config", default="data/subreddit_size_map.json")
@@ -16,7 +28,17 @@ def parse_args():
1628
return p.parse_args()
1729

1830

19-
def main():
31+
def main() -> None:
32+
"""
33+
Execute the data pipeline: scrape Reddit posts, preprocess, and split into datasets.
34+
35+
Steps:
36+
1. Parse arguments for paths and config.
37+
2. Create directories for raw and processed data.
38+
3. Load subreddit size map from JSON.
39+
4. Scrape posts and save raw JSON.
40+
5. Preprocess scraped data and split into train/val/test sets.
41+
"""
2042
cfg = parse_args()
2143

2244
# create the paths for dataset dirs

scripts/sample_data.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1+
"""Utilities for creating a reproducible smoke-test sample from a larger CSV dataset."""
2+
13
import argparse
24
import sys
35
import pandas as pd
46
from pathlib import Path
57

68

7-
def parse_args():
8-
"""Parse command-line arguments for sampling data for smoke tests."""
9+
def parse_args() -> argparse.Namespace:
10+
"""
11+
Parse command-line arguments.
12+
13+
Returns:
14+
argparse.Namespace:
15+
- input (str): Path to the full CSV file to sample from.
16+
- output (str): Path where the sampled CSV will be written.
17+
- n (int): Number of examples to sample.
18+
- seed (int): Random seed for reproducibility.
19+
"""
920
p = argparse.ArgumentParser(
1021
description="Create a smoke-test sample from a larger CSV"
1122
)
@@ -30,7 +41,23 @@ def parse_args():
3041
return p.parse_args()
3142

3243

33-
def sample_dataset(input_csv: Path, output_csv: Path, sample_size: int, seed: int = 42):
44+
def sample_dataset(
45+
input_csv: Path,
46+
output_csv: Path,
47+
sample_size: int,
48+
seed: int = 42):
49+
"""
50+
Load a CSV, draw a random sample, and write it out.
51+
52+
Args:
53+
input_csv (Path): Path to the source CSV file.
54+
output_csv (Path): Path where the sampled CSV will be saved.
55+
sample_size (int): Number of rows to sample without replacement.
56+
seed (int, optional): Random seed for sampling. Defaults to 42.
57+
58+
Raises:
59+
SystemExit: If `sample_size` exceeds the number of available rows.
60+
"""
3461
# load full train set
3562
df = pd.read_csv(input_csv)
3663
total = len(df)
@@ -51,6 +78,9 @@ def sample_dataset(input_csv: Path, output_csv: Path, sample_size: int, seed: in
5178

5279

5380
def main():
81+
"""
82+
Entry point: parse arguments and run the sampling routine.
83+
"""
5484
cfg = parse_args()
5585

5686
sample_dataset(cfg.input, cfg.output, cfg.n, cfg.seed)

scripts/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""
2+
CLI entry point for the BART Reddit LoRA training module.
3+
4+
This script imports the `main` function from `bart_reddit_lora.train`
5+
and executes it when run directly.
6+
"""
7+
18
from bart_reddit_lora.train import main
29

310
if __name__ == "__main__":

src/bart_reddit_lora/data.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
Module for scraping Reddit Q/A pairs, cleaning text, splitting data,
3+
and tokenizing for model training.
4+
"""
5+
16
import os
27
import re
38
import time
@@ -7,10 +12,22 @@
712
from pathlib import Path
813
from praw import Reddit
914
from transformers import AutoTokenizer
10-
from typing import Union, Tuple
15+
from typing import Dict, List, Any, Union, Tuple
16+
import pandas as pd
17+
18+
19+
def init_reddit() -> Reddit:
20+
"""
21+
Initialize and return a Reddit client using environment variables.
1122
23+
Environment Variables:
24+
REDDIT_CLIENT_ID: Reddit API client ID.
25+
REDDIT_CLIENT_SECRET: Reddit API client secret.
26+
REDDIT_USER_AGENT: User agent string for Reddit API.
1227
13-
def init_reddit() -> None:
28+
Returns:
29+
An authenticated praw.Reddit instance.
30+
"""
1431
return Reddit(
1532
client_id=os.environ["REDDIT_CLIENT_ID"],
1633
client_secret=os.environ["REDDIT_CLIENT_SECRET"],
@@ -19,6 +36,15 @@ def init_reddit() -> None:
1936

2037

2138
def clean_text(txt: str) -> str:
39+
"""
40+
Clean a text string by removing HTML, code fences, URLs, emojis, quotes, and extra whitespace.
41+
42+
Args:
43+
txt: Raw text to be cleaned.
44+
45+
Returns:
46+
A cleaned text string.
47+
"""
2248
# strip HTML/Markdown
2349
txt = BeautifulSoup(txt, "html.parser").get_text()
2450
# remove code fences
@@ -36,7 +62,16 @@ def clean_text(txt: str) -> str:
3662
return txt
3763

3864

39-
def scrape(sub_size_map):
65+
def scrape(sub_size_map: Dict[str, int]) -> List[Dict[str, Any]]:
66+
"""
67+
Scrape top posts and their highest-quality comments from specified subreddits.
68+
69+
Args:
70+
sub_size_map: Mapping of subreddit names to sample sizes.
71+
72+
Returns:
73+
List of dicts each containing 'id', 'subreddit', 'question', 'answer', and 'url'.
74+
"""
4075
reddit = init_reddit()
4176
qa = [] # the Q/A posts to train the model
4277

@@ -149,7 +184,16 @@ def _comment_quality(c):
149184
return qa
150185

151186

152-
def preprocess(qa_raw):
187+
def preprocess(qa_raw: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
188+
"""
189+
Apply text cleaning to raw Q/A entries.
190+
191+
Args:
192+
qa_raw: List of dicts with raw 'question' and 'answer' fields.
193+
194+
Returns:
195+
Cleaned list with same keys plus cleaned text.
196+
"""
153197
cleaned = []
154198
for item in qa_raw:
155199
q = clean_text(item["question"])
@@ -166,7 +210,14 @@ def preprocess(qa_raw):
166210
return cleaned
167211

168212

169-
def split_and_save(df, out_dir: Union[str, Path]):
213+
def split_and_save(df: pd.DataFrame, out_dir: Union[str, Path]) -> None:
214+
"""
215+
Shuffle, split, and save DataFrame into train/validation/test CSV files.
216+
217+
Args:
218+
df: DataFrame containing Q/A data.
219+
out_dir: Directory path to save CSV files.
220+
"""
170221
# create the dir path if not existing
171222
out_dir = Path(out_dir)
172223
out_dir.mkdir(parents=True, exist_ok=True)
@@ -195,6 +246,18 @@ def tokenize_and_format(
195246
max_input_length: int = 512, # max 1024 1024
196247
max_target_length: int = 128, # max 1024 800
197248
) -> Tuple[DatasetDict, AutoTokenizer]:
249+
"""
250+
Tokenize and format a DatasetDict for model training.
251+
252+
Args:
253+
ds: DatasetDict with 'question' and 'answer' columns.
254+
checkpoint: Pretrained tokenizer checkpoint identifier.
255+
max_input_length: Maximum input token length.
256+
max_target_length: Maximum target token length.
257+
258+
Returns:
259+
Tuple of tokenized DatasetDict and the tokenizer.
260+
"""
198261
tok = AutoTokenizer.from_pretrained(checkpoint)
199262

200263
def _preprocess_batch(examples):

src/bart_reddit_lora/evaluation.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,49 @@
1+
"""
2+
Metrics computation module for sequence-to-sequence models.
3+
4+
This module provides a factory function to create a `compute_metrics` callable
5+
for Hugging Face's `Trainer`. The returned function computes ROUGE-L, BLEU, and
6+
BERTScore (F1) on decoded model predictions versus labels.
7+
"""
8+
19
import numpy as np
210
import evaluate
311
from transformers import EvalPrediction
12+
from typing import Callable, Dict, Any, Union
13+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
14+
415

16+
def build_compute_metrics(
17+
tok: PreTrainedTokenizerBase,
18+
num_process_workers: int = 2
19+
) -> Callable[[EvalPrediction], Dict[str, float]]:
20+
"""
21+
Create a metrics computation function for use with Hugging Face `Trainer`.
522
6-
def build_compute_metrics(tok, num_process_workers: int = 2):
7-
"""Return a closure that Hugging Face's Trainer can call."""
23+
Args:
24+
tok: A Hugging Face tokenizer for decoding predictions/labels.
25+
num_process_workers: Number of worker processes for metric computation.
26+
27+
Returns:
28+
A callable that takes an `EvalPrediction` and returns a dict with:
29+
- "rougeL": ROUGE-L score (%)
30+
- "bleu": BLEU score (%)
31+
- "bertscore_f1": average BERTScore F1
32+
"""
833
rouge = evaluate.load("rouge") # longest-common-subsequence overlap (ROUGE-L)
934
bleu = evaluate.load("bleu") # n-gram precision
1035
bertscore = evaluate.load("bertscore") # semantic similarity
1136

12-
def _compute_metrics(eval_pred: EvalPrediction):
37+
def _compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
38+
"""
39+
Compute ROUGE-L, BLEU, and BERTScore given model predictions and labels.
40+
41+
Args:
42+
eval_pred: An `EvalPrediction` with `predictions` and `label_ids`.
43+
44+
Returns:
45+
A dict mapping metric names to rounded scores.
46+
"""
1347
preds, labels = eval_pred.predictions, eval_pred.label_ids
1448

1549
# handle tuple output (some models return (generated_ids, ...))

0 commit comments

Comments
 (0)