
Commit ef1289b

Merge pull request #1 from peremartra/April-1
Add bias visualization module with comprehensive metrics and utilities
2 parents bac3625 + 6ade254 commit ef1289b


95 files changed: +5231 −501 lines

.gitignore

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Installer logs
+pip-log.txt
+pip-delete-this.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# mypy
+.mypy_cache/
+.dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# mkdocs build
+site/
+
+# Local files
+howto.txt
+
+# Visualization test output
+visualization_test_output/

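The new ignore rules can be sanity-checked with git's own matcher, `git check-ignore`. A sketch using a throwaway repository (in the real project you would run `check-ignore` in place; the paths below are made up for illustration):

```shell
# Build a disposable repo with two of the rules added above.
dir=$(mktemp -d) && cd "$dir" && git init -q .
printf '__pycache__/\nvisualization_test_output/\n' > .gitignore

# -v reports the source file, line number, and matching pattern for a path.
git check-ignore -v somepkg/__pycache__/module.cpython-311.pyc

# The exit status alone works in scripts: 0 = ignored, 1 = not ignored.
git check-ignore -q visualization_test_output/plot.png && echo ignored
```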
CONTRIBUTING.md

Lines changed: 12 additions & 1 deletion
@@ -28,6 +28,12 @@ By participating in this project, you agree to maintain a respectful and inclusi
 3. Install development dependencies:
 ```bash
 pip install -e ".[dev]"
+
+# For working on bias visualization
+pip install -e ".[viz]"
+
+# For working on evaluation tools
+pip install -e ".[eval]"
 ```
 4. Create a new branch for your feature or bugfix:
 ```bash
@@ -98,6 +104,8 @@ For new features:
 - Add unit tests for each function or method
 - Add integration tests for interactions between components
 - Ensure tests cover both normal behavior and error cases
+- For bias visualization features, test both the numerical computations and visualization generation
+- Mock transformer models for unit tests to avoid requiring large model downloads
 
 ## Documentation
 
@@ -114,15 +122,18 @@ Documentation is a crucial part of the project. Please follow these guidelines:
 
 3. **README**: Update the README.md if your changes affect the installation, basic usage, or other key aspects.
 
+4. **Visualization Examples**: When adding new visualization features, include visual examples in the documentation.
+
 ## Future Roadmap
 
 OptiPFair is an evolving project with plans for several future enhancements. If you're interested in contributing to these areas, please join the discussion in the related issues:
 
 1. **Attention Layer Pruning**: Implementation of structured pruning for attention mechanisms.
-2. **Bias visualisations**: Implement visualizations of bias in pair prompts.
+2. **Bias-aware Pruning**: Techniques that optimize for both efficiency and fairness.
 3. **Block Pruning**: Methods for pruning entire transformer blocks.
 4. **Evaluation Framework**: Comprehensive evaluation suite for pruned models.
 5. **Fine-tuning Integration**: Tools for fine-tuning after pruning.
+6. **Extended Bias Analysis**: Support for intersectional and multi-attribute bias analysis.
 
 ## Questions?

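The testing guideline added above suggests mocking transformer models so unit tests avoid large downloads. A minimal sketch using only the standard library's `unittest.mock`; the attribute names mirror Hugging Face conventions, and `make_mock_model` / `make_mock_tokenizer` are illustrative helpers, not part of the optipfair API:

```python
# Hypothetical test helpers: a fake model/tokenizer pair with just enough
# surface area (config attributes, a call returning hidden_states) that
# code under test never needs real weights.
from unittest.mock import MagicMock

def make_mock_model(hidden_size=16, num_layers=4):
    """Mock a causal LM: config attributes plus a call returning hidden_states."""
    model = MagicMock()
    model.config.hidden_size = hidden_size
    model.config.num_hidden_layers = num_layers
    output = MagicMock()
    # HF models return one hidden-state tensor per layer, plus the embeddings
    output.hidden_states = tuple(
        [[0.0] * hidden_size] for _ in range(num_layers + 1)
    )
    model.return_value = output
    return model

def make_mock_tokenizer(seq_len=5):
    """Mock a tokenizer that returns fixed-size 'input_ids'."""
    tokenizer = MagicMock()
    tokenizer.return_value = {"input_ids": [[1] * seq_len]}
    return tokenizer

model = make_mock_model()
out = model(**make_mock_tokenizer()("any prompt"))
print(len(out.hidden_states))  # 5: embeddings + 4 layers
```

In a real test you would pass these stand-ins wherever a `model, tokenizer` pair is expected and assert on the computed metrics rather than on generated images.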
README.md

Lines changed: 36 additions & 1 deletion
@@ -8,18 +8,22 @@
 </h3>
 </div>
 
-A Python library for structured pruning of large language models, with a focus on GLU architectures.
+A Python library for structured pruning and bias visualization of large language models, with a focus on GLU architectures and fairness analysis.
+
 
 ## Overview
 
 OptiPFair enables efficient pruning of large language models while maintaining their performance. It implements various structured pruning methods, starting with MLP pruning for GLU architectures (as used in models like LLaMA, Mistral, etc.).
 
+
 Key features:
 - GLU architecture-aware pruning that preserves model structure
 - Multiple neuron importance calculation methods
 - Support for both pruning percentage and target expansion rate
 - Simple Python API and CLI interface
 - Progress tracking and detailed statistics
+- **NEW**: Bias visualization tools to analyze and understand fairness issues
+
 
 ## Installation
 
@@ -62,6 +66,37 @@ print(f"Reduction: {stats['reduction']:,} parameters ({stats['percentage_reducti
 # Save pruned model
 pruned_model.save_pretrained("./pruned-llama-model")
 ```
+### Bias Visualization
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from optipfair.bias import visualize_bias
+
+# Load model and tokenizer
+model_name = "meta-llama/Llama-3.2-1B"
+model = AutoModelForCausalLM.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Define prompt pairs that differ in demographic attributes
+prompt_pairs = [
+    ("The white man walked through the neighborhood. The police officer thought he",
+     "The Black man walked through the neighborhood. The police officer thought he")
+]
+
+# Generate visualizations and get bias metrics
+_, metrics = visualize_bias(
+    model,
+    tokenizer,
+    prompt_pairs=prompt_pairs,
+    visualization_types=["mean_diff", "pca"],
+    layers="first_middle_last",
+    output_dir="./bias_analysis"
+)
+
+# Print overall bias metrics
+overall = metrics["pair_1"]["metrics"]["overall_metrics"]
+print(f"Mean activation difference: {overall['mean_difference']:.6f}")
+```
 
 ### Command-Line Interface
 

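The `mean_difference` metric printed in the README example is, conceptually, the average absolute difference between paired activations from the two prompts. A minimal plain-Python sketch of that idea (illustrative only; optipfair's actual implementation operates on per-layer tensors):

```python
def mean_activation_difference(acts_a, acts_b):
    """Mean of |a - b| over two equal-length activation vectors."""
    if len(acts_a) != len(acts_b):
        raise ValueError("activation vectors must have equal length")
    return sum(abs(a - b) for a, b in zip(acts_a, acts_b)) / len(acts_a)

# Identical activations yield zero; divergence between the paired
# prompts' activations pushes the score up.
print(mean_activation_difference([0.1, 0.2, 0.3], [0.1, 0.2, 0.3]))  # 0.0
print(mean_activation_difference([0.0, 0.0], [1.0, 3.0]))            # 2.0
```

A score near zero means the model's internal representations barely change when the demographic attribute changes; larger scores flag layers worth inspecting with the heatmap and PCA views.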
completebias_test.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+"""
+Comprehensive test script for OptiPFair bias visualization module.
+
+This script tests all three main visualization types (mean differences, heatmap, and PCA)
+with a real model to verify functionality.
+"""
+
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from optipfair.bias import (
+    visualize_mean_differences,
+    visualize_heatmap,
+    visualize_pca,
+    visualize_bias
+)
+
+# Create output directory for saved visualizations
+OUTPUT_DIR = "./visualization_test_output"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+def test_bias_visualizations():
+    """Test all bias visualization functions with a real model."""
+    print("Loading model and tokenizer...")
+    try:
+        # Load a small model for testing - use any model you have access to
+        model_name = "meta-llama/Llama-3.2-1B"  # Change to any model you have access to
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,  # Use half precision for memory efficiency
+            device_map="auto"  # Let the library decide the best device allocation
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Define test prompt pairs
+        prompt_pairs = [
+            (
+                "The white doctor examined the patient. The nurse thought",
+                "The Black doctor examined the patient. The nurse thought"
+            )
+        ]
+        prompt_pair = prompt_pairs[0]  # Use the first pair for individual tests
+
+        print("\n" + "="*50)
+        print("TESTING INDIVIDUAL VISUALIZATION FUNCTIONS")
+        print("="*50)
+
+        # Test 1: Mean Differences Visualization
+        print("\nTesting mean differences visualization...")
+        try:
+            visualize_mean_differences(
+                model,
+                tokenizer,
+                prompt_pair,
+                layer_type="mlp_output",
+                layers="first_middle_last",
+                output_dir=os.path.join(OUTPUT_DIR, "mean_diff"),
+                figure_format="png"
+            )
+            print("✅ Mean differences visualization successful!")
+        except Exception as e:
+            print(f"❌ Mean differences visualization failed: {e}")
+
+        # Test 2: Heatmap Visualization
+        # For heatmap we need a specific layer, so let's get layer 8 or the middle layer
+        print("\nTesting heatmap visualization...")
+        try:
+            visualize_heatmap(
+                model,
+                tokenizer,
+                prompt_pair,
+                layer_key="mlp_output_layer_8",  # Using middle layer - adjust if your model has fewer layers
+                output_dir=os.path.join(OUTPUT_DIR, "heatmap"),
+                figure_format="png"
+            )
+            print("✅ Heatmap visualization successful!")
+        except Exception as e:
+            print(f"❌ Heatmap visualization failed: {e}")
+            # If the specific layer fails, try with layer 0 which should exist in any model
+            print("Retrying with layer 0...")
+            try:
+                visualize_heatmap(
+                    model,
+                    tokenizer,
+                    prompt_pair,
+                    layer_key="mlp_output_layer_0",
+                    output_dir=os.path.join(OUTPUT_DIR, "heatmap"),
+                    figure_format="png"
+                )
+                print("✅ Heatmap visualization with layer 0 successful!")
+            except Exception as e2:
+                print(f"❌ Heatmap visualization with layer 0 also failed: {e2}")
+
+        # Test 3: PCA Visualization
+        print("\nTesting PCA visualization...")
+        try:
+            visualize_pca(
+                model,
+                tokenizer,
+                prompt_pair,
+                layer_key="attention_output_layer_8",  # Using middle attention layer
+                highlight_diff=True,
+                output_dir=os.path.join(OUTPUT_DIR, "pca"),
+                figure_format="png"
+            )
+            print("✅ PCA visualization successful!")
+        except Exception as e:
+            print(f"❌ PCA visualization failed: {e}")
+            # If the specific layer fails, try with layer 0
+            print("Retrying with layer 0...")
+            try:
+                visualize_pca(
+                    model,
+                    tokenizer,
+                    prompt_pair,
+                    layer_key="attention_output_layer_0",
+                    highlight_diff=True,
+                    output_dir=os.path.join(OUTPUT_DIR, "pca"),
+                    figure_format="png"
+                )
+                print("✅ PCA visualization with layer 0 successful!")
+            except Exception as e2:
+                print(f"❌ PCA visualization with layer 0 also failed: {e2}")
+
+        # Test 4: Main visualize_bias function (combines all visualization types)
+        print("\n" + "="*50)
+        print("TESTING MAIN VISUALIZATION FUNCTION")
+        print("="*50)
+
+        print("\nTesting visualize_bias function...")
+        try:
+            _, metrics = visualize_bias(
+                model,
+                tokenizer,
+                prompt_pairs=prompt_pairs,
+                visualization_types=["mean_diff", "heatmap", "pca"],
+                layers="first_middle_last",
+                output_dir=os.path.join(OUTPUT_DIR, "combined"),
+                figure_format="png",
+                show_progress=True
+            )
+
+            print("✅ visualize_bias function successful!")
+
+            # Print some metrics to verify they're being calculated correctly
+            if metrics and "pair_1" in metrics:
+                overall = metrics["pair_1"]["metrics"]["overall_metrics"]
+                print("\nMetrics sample:")
+                print(f"  Overall mean difference: {overall['mean_difference']:.6f}")
+                print(f"  Max difference: {overall['max_difference']:.6f}")
+
+                # Check if we have component metrics
+                if "component_metrics" in metrics["pair_1"]["metrics"]:
+                    comp_metrics = metrics["pair_1"]["metrics"]["component_metrics"]
+                    for comp_name, comp_data in comp_metrics.items():
+                        if comp_name in ["mlp_output", "attention_output"]:
+                            print(f"  {comp_name} mean difference: {comp_data['mean_difference']:.6f}")
+        except Exception as e:
+            print(f"❌ visualize_bias function failed: {e}")
+
+        print("\nTests completed. Check the output directory for visualization results:")
+        print(f"  {os.path.abspath(OUTPUT_DIR)}")
+
+    except Exception as e:
+        print(f"Failed to load model: {e}")
+        print("Please make sure you have access to the specified model and that your environment is set up correctly.")
+
+if __name__ == "__main__":
+    test_bias_visualizations()

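The test script hard-codes layer 8 as the "middle" layer and falls back to layer 0 when it doesn't exist. A small model-agnostic sketch of picking the middle layer key instead (hypothetical helper; assumes an HF-style `config.num_hidden_layers` and the `<component>_layer_<n>` key format used in the script):

```python
def middle_layer_key(num_hidden_layers, component="mlp_output"):
    """Build a layer key like 'mlp_output_layer_8' for the model's middle layer."""
    return f"{component}_layer_{num_hidden_layers // 2}"

# With a real model you would pass model.config.num_hidden_layers.
print(middle_layer_key(16))                      # mlp_output_layer_8
print(middle_layer_key(12, "attention_output"))  # attention_output_layer_6
```

Deriving the key from the model's config would make the heatmap and PCA tests work unchanged on models of any depth, without the retry-with-layer-0 fallback.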