|
| 1 | +from sackmann import get_data |
| 2 | +import pymc as pm |
| 3 | +from sklearn.preprocessing import LabelEncoder |
| 4 | +import numpy as np |
| 5 | + |
| 6 | + |
| 7 | +def create_arrays( |
| 8 | + start_year=1960, |
| 9 | + data_dir="./tennis_atp", |
| 10 | + include_qualifying_and_challengers=False, |
| 11 | + include_futures=False, |
| 12 | +): |
| 13 | + |
| 14 | + df = get_data( |
| 15 | + data_dir, |
| 16 | + include_qualifying_and_challengers=include_qualifying_and_challengers, |
| 17 | + include_futures=include_futures, |
| 18 | + ) |
| 19 | + |
| 20 | + rel_df = df[df["tourney_date"].dt.year >= start_year] |
| 21 | + |
| 22 | + encoder = LabelEncoder() |
| 23 | + |
| 24 | + encoder.fit( |
| 25 | + rel_df["winner_name"].values.tolist() + rel_df["loser_name"].values.tolist() |
| 26 | + ) |
| 27 | + |
| 28 | + winner_ids = encoder.transform(rel_df["winner_name"]) |
| 29 | + loser_ids = encoder.transform(rel_df["loser_name"]) |
| 30 | + |
| 31 | + return { |
| 32 | + "winner_ids": winner_ids, |
| 33 | + "loser_ids": loser_ids, |
| 34 | + "player_encoder": encoder, |
| 35 | + } |
| 36 | + |
| 37 | + |
| 38 | +def get_pymc_model(start_year=1960, data_dir="./tennis_atp"): |
| 39 | + |
| 40 | + arrays = create_arrays(start_year=start_year, data_dir=data_dir) |
| 41 | + |
| 42 | + n_players = len(arrays["player_encoder"].classes_) |
| 43 | + |
| 44 | + winner_ids = arrays["winner_ids"] |
| 45 | + loser_ids = arrays["loser_ids"] |
| 46 | + |
| 47 | + with pm.Model() as model: |
| 48 | + |
| 49 | + player_sd = pm.HalfNormal("player_sd", sigma=1.0) |
| 50 | + |
| 51 | + player_skills_raw = pm.Normal( |
| 52 | + "player_skills_raw", 0.0, sigma=1.0, shape=(n_players,) |
| 53 | + ) |
| 54 | + |
| 55 | + player_skills = pm.Deterministic("player_skills", player_skills_raw * player_sd) |
| 56 | + logit_skills = player_skills[winner_ids] - player_skills[loser_ids] |
| 57 | + |
| 58 | + lik = pm.Bernoulli( |
| 59 | + "win_lik", logit_p=logit_skills, observed=np.ones(winner_ids.shape[0]) |
| 60 | + ) |
| 61 | + |
| 62 | + return model |
0 commit comments