Skip to content

Commit b5c9dcf

Browse files
authored
example: generating data for large-scale pretraining (#13)
* . * .
1 parent 96489bd commit b5c9dcf

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed

generate_training_data/README.md

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Generate Codebase Pre-Training Data
2+
3+
[![Documentation](https://img.shields.io/badge/docs-docs.codegen.com-blue)](https://docs.codegen.com/tutorials/generate-training-data)
4+
5+
This example demonstrates how to use Codegen to generate training data for large-scale LLM pre-training by extracting function implementations along with their dependencies and usages. The approach is inspired by node2vec, leveraging code graphs for learning.
6+
7+
## What This Example Does
8+
9+
The script analyzes your codebase and generates training data by:
10+
11+
1. **Finding All Functions**
12+
- Scans the entire codebase to identify function definitions
13+
- Filters out trivial functions (less than 2 lines)
14+
15+
2. **Capturing Implementation Context**
16+
```python
17+
{
18+
"implementation": {
19+
"source": "def process_data():\n ...",
20+
"filepath": "src/process.py"
21+
}
22+
}
23+
```
24+
25+
3. **Extracting Dependencies**
26+
```python
27+
{
28+
"dependencies": [
29+
{
30+
"source": "def helper_function():\n ...",
31+
"filepath": "src/helpers.py"
32+
}
33+
]
34+
}
35+
```
36+
37+
4. **Recording Usages**
38+
```python
39+
{
40+
"usages": [
41+
{
42+
"source": "result = process_data()",
43+
"filepath": "src/main.py"
44+
}
45+
]
46+
}
47+
```
48+
49+
## Running the Example
50+
51+
```bash
52+
# Install Codegen
53+
pip install codegen
54+
55+
# Run the data generation
56+
python run.py
57+
```
58+
59+
The script will analyze your codebase and output a `training_data.json` file containing the structured training data.
60+
61+
## Understanding the Code
62+
63+
- `run.py` - The main script that generates the training data
64+
- Uses `get_function_context()` to extract implementation, dependencies, and usages
65+
- Processes each function and builds a comprehensive context graph
66+
- Outputs structured JSON data with metadata about the processing
67+
68+
## Output Format
69+
70+
The generated `training_data.json` follows this structure:
71+
```json
72+
{
73+
"functions": [
74+
{
75+
"implementation": { "source": "...", "filepath": "..." },
76+
"dependencies": [{ "source": "...", "filepath": "..." }],
77+
"usages": [{ "source": "...", "filepath": "..." }]
78+
}
79+
],
80+
"metadata": {
81+
"total_functions": 100,
82+
"total_processed": 85,
83+
"avg_dependencies": 2.5,
84+
"avg_usages": 3.2
85+
}
86+
}
87+
```
88+
89+
## Learn More
90+
91+
- [Full Tutorial](https://docs.codegen.com/tutorials/generate-training-data)
92+
- [Code Model Pre-training](https://docs.codegen.com/concepts/code-model-training)
93+
- [Codegen Documentation](https://docs.codegen.com)

generate_training_data/run.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import json
2+
3+
import codegen
4+
from codegen import Codebase
5+
from codegen.sdk.core.external_module import ExternalModule
6+
from codegen.sdk.core.import_resolution import Import
7+
from codegen.sdk.core.symbol import Symbol
8+
9+
10+
def hop_through_imports(imp: Import) -> Symbol | ExternalModule:
    """Resolve an import chain down to its root target.

    An import may point at another import (a re-export). Follow that
    chain until the target is no longer an Import, then return the
    underlying Symbol or ExternalModule.
    """
    target = imp.imported_symbol
    while isinstance(target, Import):
        target = target.imported_symbol
    return target
15+
16+
17+
def get_function_context(function) -> dict:
    """Build a training record for one function.

    The record contains the function's own source ("implementation"),
    the source of each dependency (with imports resolved to the root
    symbol's source), and the source of each site that uses it.
    """
    dependency_entries = []
    for dependency in function.dependencies:
        # Resolve import indirection so we record the root symbol's source.
        if isinstance(dependency, Import):
            dependency = hop_through_imports(dependency)
        dependency_entries.append(
            {"source": dependency.source, "filepath": dependency.filepath}
        )

    usage_entries = [
        {
            "source": usage.usage_symbol.source,
            "filepath": usage.usage_symbol.filepath,
        }
        for usage in function.usages
    ]

    return {
        "implementation": {"source": function.source, "filepath": function.filepath},
        "dependencies": dependency_entries,
        "usages": usage_entries,
    }
43+
44+
45+
@codegen.function("generate-training-data")
def run(codebase: Codebase):
    """Generate training data using a node2vec-like approach for code embeddings.

    This codemod:
    1. Finds all functions in the codebase
    2. For each function:
       - Captures its implementation
       - Lists all dependencies (with their implementations)
       - Lists all usages (with their implementations)
    3. Outputs structured JSON data for training
    """
    total_functions = len(codebase.functions)

    contexts = []
    for fn in codebase.functions:
        # Skip trivial one-line functions: too little signal to train on.
        if len(fn.source.split("\n")) < 2:
            continue

        ctx = get_function_context(fn)

        # Keep only functions connected to the graph (>= 1 dep or usage).
        if ctx["dependencies"] or ctx["usages"]:
            contexts.append(ctx)

    processed = len(contexts)
    avg_dependencies = 0
    avg_usages = 0
    if contexts:
        avg_dependencies = sum(len(c["dependencies"]) for c in contexts) / processed
        avg_usages = sum(len(c["usages"]) for c in contexts) / processed

    training_data = {
        "functions": contexts,
        "metadata": {
            "total_functions": total_functions,
            "total_processed": processed,
            "avg_dependencies": avg_dependencies,
            "avg_usages": avg_usages,
        },
    }

    # Print stats
    print(f"Processed {training_data['metadata']['total_processed']} functions")
    print(f"Average dependencies: {training_data['metadata']['avg_dependencies']:.2f}")
    print(f"Average usages: {training_data['metadata']['avg_usages']:.2f}")

    return training_data
97+
98+
99+
if __name__ == "__main__":
    # Clone/analyze a real-world repo, build the dataset, and persist it.
    print("Initializing codebase...")
    repo = Codebase.from_repo("fastapi/fastapi")

    print("Generating training data...")
    dataset = run(repo)

    print("Saving training data...")
    with open("training_data.json", "w") as out:
        json.dump(dataset, out, indent=2)
    print("Training data saved to training_data.json")

0 commit comments

Comments
 (0)