Skip to content

Commit 5da7e9d

Browse files
committed
🧹 Refactor ESR into one function that also works without embeddings provided
1 parent 82a427d commit 5da7e9d

File tree

7 files changed

+547
-1401
lines changed

7 files changed

+547
-1401
lines changed

README.md

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,51 @@ pip install git+https://github.com/pymc-labs/embeddings-similarity-rating.git
2525
## Quick Start
2626

2727
```python
28-
import numpy as np
2928
import polars as po
30-
from embeddings_similarity_rating import EmbeddingsRater
31-
32-
# Create reference sentences with embeddings
33-
reference_data = po.DataFrame({
34-
'id': ['set1'] * 5,
35-
'int_response': [1, 2, 3, 4, 5],
36-
'sentence': [
37-
"It's very unlikely that I'd buy it.",
38-
"It's unlikely that I'd buy it.",
39-
"I might buy it or not. I don't know.",
40-
"It's somewhat possible I'd buy it.",
41-
"It's possible I'd buy it."
42-
],
43-
'embedding_small': [np.random.rand(384).tolist() for _ in range(5)]
44-
})
45-
46-
# Initialize the rater
47-
rater = EmbeddingsRater(reference_data, embeddings_column='embedding_small')
48-
49-
# Convert LLM response embeddings to probability distributions
50-
llm_responses = np.random.rand(10, 384)
51-
pmfs = rater.get_response_pmfs('set1', llm_responses)
52-
53-
# Get overall survey distribution
29+
import numpy as np
30+
from embeddings_similarity_rating import ResponseRater
31+
32+
# Create example reference sentences dataframe
33+
reference_set_1 = [
34+
"Strongly disagree",
35+
"Disagree",
36+
"Neutral",
37+
"Agree",
38+
"Strongly agree",
39+
]
40+
reference_set_2 = [
41+
"Disagree a lot",
42+
"Kinda disagree",
43+
"Don't know",
44+
"Kinda agree",
45+
"Agree a lot",
46+
]
47+
df = po.DataFrame(
48+
{
49+
"id": ["set1"] * 5 + ["set2"] * 5,
50+
"int_response": [1, 2, 3, 4, 5] * 2,
51+
"sentence": reference_set_1 + reference_set_2,
52+
}
53+
)
54+
55+
# Initialize rater
56+
rater = ResponseRater(df)
57+
58+
# Create some example synthetic consumer responses
59+
llm_responses = ["I totally agree", "Not sure about this", "Completely disagree"]
60+
61+
# Get PMFs for synthetic consumer responses
62+
pmfs = rater.get_response_pmfs(
63+
reference_set_id="set1", # Reference set to score against, or "mean"
64+
llm_responses=llm_responses, # List of LLM responses to score
65+
temperature=1.0, # Temperature for scaling the PMF
66+
epsilon=0.0, # Small regularization parameter to prevent division by zero and add smoothing
67+
)
68+
69+
# Get survey response PMF
5470
survey_pmf = rater.get_survey_response_pmf(pmfs)
55-
print(f"Survey distribution: {survey_pmf}")
71+
72+
print(survey_pmf)
5673
```
5774

5875
## Methodology
@@ -65,9 +82,8 @@ The ESR methodology works by:
6582

6683
## Core Components
6784

68-
- `EmbeddingsRater`: Main class implementing the ESR methodology
69-
- `response_embeddings_to_pmf()`: Core function for similarity-to-probability conversion
70-
- `scale_pmf()`: Temperature scaling function
85+
- `ResponseRater`: Main class implementing the ESR methodology
86+
- `get_response_pmfs()`: Convert LLM response embeddings to PMFs using specified reference set
7187

7288
## Citation
7389

embeddings_similarity_rating/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111
from beartype.claw import beartype_this_package
1212

1313
from .compute import response_embeddings_to_pmf, scale_pmf
14-
from .embeddings_rater import EmbeddingsRater
1514
from .response_rater import ResponseRater
1615

1716
__version__ = "1.0.0"
1817
__author__ = "Ben F. Maier, Ulf Aslak"
1918

2019
__all__ = [
21-
"EmbeddingsRater",
2220
"ResponseRater",
2321
"response_embeddings_to_pmf",
2422
"scale_pmf",

embeddings_similarity_rating/embeddings_rater.py

Lines changed: 0 additions & 219 deletions
This file was deleted.

0 commit comments

Comments (0)