A generic text classification framework using Gemini 2.0 Flash in JSON mode. Customize the input models, output schema, and data sources to adapt to your classification needs.
Before you begin, make sure you have the following installed:
- `uv`: This project uses `uv` for package management. Install it with:

  ```bash
  curl -LsSf https://astral.sh/uv/install.sh | sh
  ```

  See the uv documentation for more details.
- Python: Install Python using `uv`:

  ```bash
  uv python install
  ```
- Clone the repository:

  ```bash
  git clone https://github.com/chriscarrollsmith/llm-classifier.git
  cd llm-classifier
  ```
- Install the project dependencies using `uv`:

  ```bash
  uv sync
  ```
- Create a `.env` file:

  ```
  GEMINI_API_KEY=your_key_here
  DB_PATH=classifications.db
  CONCURRENCY_LIMIT=1
  ```
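  If you want to confirm the variables are picked up, a quick check like this works (a sketch assuming `python-dotenv`, which the project may or may not already depend on; add it with `uv add python-dotenv` if needed):

  ```python
  import os

  from dotenv import load_dotenv

  # Load .env from the current directory and verify the required key is set
  load_dotenv()
  assert os.getenv("GEMINI_API_KEY"), "GEMINI_API_KEY is missing from .env"
  print("DB_PATH =", os.getenv("DB_PATH", "classifications.db"))
  ```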
- Implement your custom components in `prompt.py` and `main.py`, as described below.

Define your `Input` and `Response` models:
```python
from sqlmodel import SQLModel


class Input(SQLModel, table=False):
    """Custom fields for your input data"""
    title: str
    body: str
    author: str


class Response(SQLModel, table=False):
    """Define your classification schema"""
    key_insight: str
    severity: int
    category: str
```
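Because SQLModel models are Pydantic models under the hood, you can sanity-check a sample of the JSON you expect from the LLM against your `Response` schema before running the pipeline (a sketch; assumes a Pydantic v2-based SQLModel, where `model_validate` is available):

```python
import json

# Hypothetical sample of what Gemini should return for this schema
sample = '{"key_insight": "Example insight", "severity": 7, "category": "Technology"}'

# Raises a ValidationError if the JSON doesn't match Response
# (missing field, wrong type, etc.)
parsed = Response.model_validate(json.loads(sample))
print(parsed.severity)  # 7
```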
Then write your prompt template. Requirements:

- All required `Input` fields must be present as `{placeholder}` variables
- Include example JSON matching your `Response` model
- Provide clear formatting and classification instructions
PROMPT_TEMPLATE = """
Analyze this post from {author}:
{title}
{body}
Return JSON with:
- "key_insight" (most important finding)
- "severity" (1-10)
- "category" (most relevant topic)
Example:
{{
"key_insight": "Example insight",
"severity": 7,
"category": "Technology"
}}
"""
In `main.py`, define the document types to process:

```python
seed_input_types(session, input_types=["Blogs", "Tweets"])
```

There must be at least one input type.
Choose a strategy based on your API:
Bulk Download Approach:

```python
import requests
from typing import override  # Python 3.12+; use typing_extensions on older versions

# Downloader, InputType, and ClassificationInput come from the framework


class CustomDownloader(Downloader):
    @classmethod
    @override
    def get_records(cls, input_type: InputType) -> list[ClassificationInput]:
        response = requests.get("https://api.example.com/data")
        return [
            ClassificationInput(
                body=item["content"],
                title=item["title"],
                author=item["author"],
                input_type_id=input_type.id,
            )
            for item in response.json()
        ]
```
Per-Record Approach:

```python
class CustomDownloader(Downloader):
    @classmethod
    @override
    def get_record_ids(cls, input_type: InputType) -> list[int]:
        return requests.get("https://api.example.com/items/list").json()

    @classmethod
    @override
    def get_record(cls, record_id: int, input_type: InputType) -> ClassificationInput:
        # input_type is assumed to be passed in alongside the record ID here;
        # match the signature to your Downloader base class
        item = requests.get(f"https://api.example.com/items/{record_id}").json()
        return ClassificationInput(
            body=item["content"],
            title=item["title"],
            author=item["author"],
            input_type_id=input_type.id,
        )
```
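The per-record approach suits APIs that expose an index endpoint plus one endpoint per item; prefer the bulk approach when a single request can return every record.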
Summarization:

To use `print_summary_statistics`, your `Response` model must include at least one numeric field. Otherwise, delete or comment out the `print_summary_statistics` call in `main.py`.
```python
print_summary_statistics(
    session,
    numeric_field="severity",  # Name of your numeric response field
    breakpoints=4,             # Percentile breakpoints (e.g., 4 prints quartiles)
)
```
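For intuition, `breakpoints=4` corresponds to quartile cut points, which you can reproduce with the standard library (made-up scores here; the framework's exact percentile method may differ):

```python
from statistics import quantiles

# Hypothetical severity scores for ten classified documents
severities = [2, 3, 5, 5, 6, 7, 7, 8, 9, 10]

# n=4 returns the three cut points that split the data into quartiles
print(quantiles(severities, n=4))  # [4.5, 6.5, 8.25]
```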
Export Filtering:

To filter exported responses, optionally pass a list of SQLAlchemy expressions as the `where_clauses` argument of the `export_responses` function. Each clause is an expression that filters the responses. In the `input_fields` argument, specify any fields from the `Input` model that you want to include with the `Response` data in the exported CSV.
```python
export_responses(
    session,
    "results.csv",
    where_clauses=[
        ClassificationResponse.severity >= 7,
        ClassificationResponse.category == "Security",
    ],
    input_fields=["id", "processed_date", "title"],
)
```
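The clauses in the list are presumably combined with AND. If you need OR logic instead, SQLAlchemy's `or_` composes multiple conditions into a single clause (the filename and thresholds here are illustrative):

```python
from sqlalchemy import or_

export_responses(
    session,
    "urgent_or_security.csv",
    where_clauses=[
        or_(
            ClassificationResponse.severity >= 9,
            ClassificationResponse.category == "Security",
        )
    ],
    input_fields=["id", "title"],
)
```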
The `prompt.py` and `main.py` files in this repo contain an example implementation of the framework that downloads and processes data from the public JSONPlaceholder API. To adapt the framework to your own classification needs:

- Change the `Input` and `Response` models to match your use case
- Create a prompt template with the required placeholders
- Implement a data downloader for your API
- Configure input types in `main.py`
- Customize export filters and summary fields
The framework handles:
- Database management
- Parallel LLM API calls with rate limiting and retries
- Response parsing and validation
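For a rough mental model of how `CONCURRENCY_LIMIT` bounds the parallel calls, here is the classic semaphore pattern (a sketch only; the framework's actual implementation may differ):

```python
import asyncio


async def classify_all(prompts: list[str], limit: int = 1) -> list[str]:
    # At most `limit` classify() calls run concurrently
    semaphore = asyncio.Semaphore(limit)

    async def classify(prompt: str) -> str:
        async with semaphore:
            await asyncio.sleep(0.1)  # stand-in for the Gemini API call
            return '{"key_insight": "...", "severity": 5, "category": "..."}'

    return await asyncio.gather(*(classify(p) for p in prompts))
```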