|
| 1 | +import numpy as np |
| 2 | + |
def compute_pmi(joint_counts, total_counts_x, total_counts_y, total_samples):
    """Compute the pointwise mutual information (PMI) of two events, in bits.

    PMI is ``log2(p(x, y) / (p(x) * p(y)))``, where each probability is the
    corresponding count divided by ``total_samples``.

    Args:
        joint_counts: Number of samples in which x and y occur together.
        total_counts_x: Number of samples in which x occurs.
        total_counts_y: Number of samples in which y occurs.
        total_samples: Total number of samples.

    Returns:
        The PMI rounded to three decimal places, or ``float('-inf')`` when
        any of the three probabilities is zero.

    Raises:
        ValueError: If an input is not a real-valued Python or NumPy scalar,
            is NaN or infinite, is negative, or if the counts are
            inconsistent (zero total, joint count above an individual
            count, or an individual count above the total).
    """
    counts = (joint_counts, total_counts_x, total_counts_y, total_samples)

    # NumPy integer scalars (e.g. np.int64) do not subclass ``int``, so they
    # are listed explicitly; otherwise counts taken from arrays are rejected.
    numeric_types = (int, float, np.integer, np.floating)
    if not all(isinstance(c, numeric_types) for c in counts):
        raise ValueError("All inputs must be numeric")

    # NaN compares False against everything, so it would pass every range
    # check below and yield a NaN result; infinity would yield a bogus -inf.
    # Only float types are probed so arbitrarily large Python ints still work.
    if any(
        isinstance(c, (float, np.floating)) and not np.isfinite(c)
        for c in counts
    ):
        raise ValueError("All inputs must be finite")

    if any(c < 0 for c in counts):
        raise ValueError("Counts cannot be negative")

    if total_samples == 0:
        raise ValueError("Total samples cannot be zero")

    if joint_counts > min(total_counts_x, total_counts_y):
        raise ValueError("Joint counts cannot exceed individual counts")

    if any(c > total_samples for c in (total_counts_x, total_counts_y)):
        raise ValueError("Individual counts cannot exceed total samples")

    p_x = total_counts_x / total_samples
    p_y = total_counts_y / total_samples
    p_xy = joint_counts / total_samples

    # log2(0) is undefined; by convention a zero probability maps to -inf.
    if p_xy == 0 or p_x == 0 or p_y == 0:
        return float('-inf')

    pmi = np.log2(p_xy / (p_x * p_y))

    return round(pmi, 3)
| 31 | + |
def test_pmi():
    """Run compute_pmi against fixed scenarios and one invalid input.

    Each scenario pairs the arguments passed to compute_pmi with the value
    it is expected to return; the first mismatch raises AssertionError with
    that scenario's message. Prints a confirmation line when all checks pass.
    """
    scenarios = [
        # x and y occur in every sample: the ratio is 1, so PMI is 0.0.
        (
            "Test Case 1 Failed",
            (100, 100, 100, 100),
            round(np.log2(1/(1*1)), 3),
        ),
        # Independence: p(x, y) == p(x) * p(y), so PMI is 0.0.
        (
            "Test Case 2 Failed",
            (25, 50, 50, 100),
            round(np.log2((25/100)/((50/100)*(50/100))), 3),
        ),
        # Negative association: co-occurrence is rarer than under independence.
        (
            "Test Case 3 Failed",
            (10, 50, 50, 100),
            round(np.log2((10/100)/((50/100)*(50/100))), 3),
        ),
        # Zero joint occurrence maps to negative infinity.
        (
            "Test Case 4 Failed",
            (0, 50, 50, 100),
            float('-inf'),
        ),
    ]

    for failure_message, arguments, expected in scenarios:
        assert compute_pmi(*arguments) == expected, failure_message

    # Invalid input: a negative count must be rejected with ValueError.
    rejected = False
    try:
        compute_pmi(-1, 50, 50, 100)
    except ValueError:
        rejected = True
    assert rejected, "Test Case 5 Failed: Should raise ValueError for negative counts"

    print("All Test Cases Passed!")
# Run the self-checks only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_pmi()
0 commit comments