Skip to content

Commit 0c631f4

Browse files
authored
Merge pull request #325 from saitiger/PMI
Solution and Learn for Pointwise Mutual Information
2 parents 36a8c4c + aa02108 commit 0c631f4

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

Problems/111_PMI/Learn.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Pointwise Mutual Information (PMI)
2+
3+
Pointwise Mutual Information (PMI) is a statistical measure used in information theory and Natural Language Processing (NLP) to quantify the level of association between two events. It compares the probability of two events
4+
occurring together versus the probability of them occurring independently. It is commonly used in Natural Language Processing(NLP) and Information Retrieval to find association between two words, feature selection in text classification,
5+
document similarity.
6+
7+
## Implementation
8+
1. **Collect Count Data for event x, y and joint occurence**
9+
10+
2. **Calculate Individual Probabilities**
11+
12+
3. **Calculate Joint Probability**
13+
14+
4. **Final Score : PMI(x,y) = log₂(P(x,y) / (P(x) * P(y)))**
15+
Where:
16+
- P(x,y) is the probability of events x and y occurring together
17+
18+
- P(x) is the probability of event x occurring
19+
20+
- P(y) is the probability of event y occurring
21+
22+
## Interpreting PMI Values
23+
24+
- **Positive PMI**: Events co-occur more than expected by chance
25+
- **Zero PMI**: Events are statistically independent
26+
- **Negative PMI**: Events co-occur less than expected by chance
27+
- **Undefined**: When P(x,y) = 0 (events never co-occur)
28+
29+
## Variants
30+
31+
### 1. Normalized PMI (NPMI)
32+
- Scales PMI to range [-1, 1]
33+
- Easier to compare across different datasets
34+
- Formula: NPMI(x,y) = PMI(x,y) / -log₂(P(x,y))
35+
36+
### 2. Positive PMI (PPMI)
37+
- Sets negative PMI values to zero
38+
- Commonly used in word embedding models
39+
- Formula: PPMI(x,y) = max(PMI(x,y), 0)

Problems/111_PMI/solution.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
3+
def compute_pmi(joint_counts, total_counts_x, total_counts_y, total_samples):
4+
5+
if not all(isinstance(x, (int, float)) for x in [joint_counts, total_counts_x, total_counts_y, total_samples]):
6+
raise ValueError("All inputs must be numeric")
7+
8+
if any(x < 0 for x in [joint_counts, total_counts_x, total_counts_y, total_samples]):
9+
raise ValueError("Counts cannot be negative")
10+
11+
if total_samples == 0:
12+
raise ValueError("Total samples cannot be zero")
13+
14+
if joint_counts > min(total_counts_x, total_counts_y):
15+
raise ValueError("Joint counts cannot exceed individual counts")
16+
17+
if any(x > total_samples for x in [total_counts_x, total_counts_y]):
18+
raise ValueError("Individual counts cannot exceed total samples")
19+
20+
p_x = total_counts_x / total_samples
21+
p_y = total_counts_y / total_samples
22+
p_xy = joint_counts / total_samples
23+
24+
# Handle edge cases where probabilities are zero
25+
if p_xy == 0 or p_x == 0 or p_y == 0:
26+
return float('-inf')
27+
28+
pmi = np.log2(p_xy / (p_x * p_y))
29+
30+
return round(pmi, 3)
31+
32+
def test_pmi():
33+
# Test Case 1: Perfect positive association
34+
joint_counts1 = 100
35+
total_counts_x1 = 100
36+
total_counts_y1 = 100
37+
total_samples1 = 100
38+
expected1 = round(np.log2(1/(1*1)), 3) # Should be 0.0
39+
assert compute_pmi(joint_counts1, total_counts_x1, total_counts_y1, total_samples1) == expected1, "Test Case 1 Failed"
40+
41+
# Test Case 2: Independence
42+
joint_counts2 = 25
43+
total_counts_x2 = 50
44+
total_counts_y2 = 50
45+
total_samples2 = 100
46+
expected2 = round(np.log2((25/100)/((50/100)*(50/100))), 3) # Should be 0.0
47+
assert compute_pmi(joint_counts2, total_counts_x2, total_counts_y2, total_samples2) == expected2, "Test Case 2 Failed"
48+
49+
# Test Case 3: Negative association
50+
joint_counts3 = 10
51+
total_counts_x3 = 50
52+
total_counts_y3 = 50
53+
total_samples3 = 100
54+
expected3 = round(np.log2((10/100)/((50/100)*(50/100))), 3) # Should be negative
55+
assert compute_pmi(joint_counts3, total_counts_x3, total_counts_y3, total_samples3) == expected3, "Test Case 3 Failed"
56+
57+
# Test Case 4: Zero joint occurrence
58+
joint_counts4 = 0
59+
total_counts_x4 = 50
60+
total_counts_y4 = 50
61+
total_samples4 = 100
62+
expected4 = float('-inf')
63+
assert compute_pmi(joint_counts4, total_counts_x4, total_counts_y4, total_samples4) == expected4, "Test Case 4 Failed"
64+
65+
# Test Case 5: Invalid inputs
66+
try:
67+
compute_pmi(-1, 50, 50, 100)
68+
assert False, "Test Case 5 Failed: Should raise ValueError for negative counts"
69+
except ValueError:
70+
pass
71+
72+
print("All Test Cases Passed!")
73+
74+
if __name__ == "__main__":
75+
test_pmi()

0 commit comments

Comments
 (0)