Skip to content

Commit 25858c2

Browse files
committed
add model training on synthetic data
1 parent af2a128 commit 25858c2

File tree

1 file changed

+125
-7
lines changed

1 file changed

+125
-7
lines changed

notebooks/synthetic_experiments.ipynb

Lines changed: 125 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
"source": [
2626
"import sys\n",
2727
"sys.path.append(\"../\")\n",
28+
"import matplotlib.pyplot as plt\n",
29+
"import numpy as np\n",
2830
"\n",
2931
"import choice_learn\n",
30-
"from python.data import SyntheticDataGenerator"
32+
"from python.data import SyntheticDataGenerator\n",
33+
"from choice_learn.basket_models import Trip, TripDataset"
3134
]
3235
},
3336
{
@@ -68,25 +71,140 @@
6871
"dataset = data_gen.generate_dataset(n_baskets=1000)"
6972
]
7073
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"id": "51791e7e",
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"trip_list = []\n",
82+
"for basket in dataset:\n",
83+
" trip_list.append(Trip(purchases=basket, prices=np.zeros((8, )), assortment=0))\n",
84+
"\n",
85+
"trip_dataset = TripDataset(trips=trip_list, available_items=np.ones((1, 8)))"
86+
]
87+
},
7188
{
7289
"cell_type": "markdown",
73-
"id": "f337217b",
90+
"id": "52b4b18c",
91+
"metadata": {},
92+
"source": [
93+
"## Modelling"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"id": "3d6c32e2",
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"from choice_learn.basket_models import AleaCarta"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"id": "6ef517b6",
74110
"metadata": {},
111+
"outputs": [],
75112
"source": [
76-
"### Sample purchased baskets\n",
113+
"latent_sizes = {\"preferences\": 6, \"price\": 3, \"season\": 3}\n",
114+
"n_negative_samples = 2\n",
115+
"optimizer = \"adam\"\n",
116+
"lr = 1e-2\n",
117+
"epochs = 200\n",
118+
"batch_size = 32\n",
77119
"\n",
78-
"### Modelling\n",
120+
"model = AleaCarta(\n",
121+
" item_intercept=False,\n",
122+
" price_effects=False,\n",
123+
" seasonal_effects=False,\n",
124+
" latent_sizes=latent_sizes,\n",
125+
" n_negative_samples=n_negative_samples,\n",
126+
" optimizer=optimizer,\n",
127+
" lr=lr,\n",
128+
" epochs=epochs,\n",
129+
" batch_size=batch_size,\n",
130+
")\n",
79131
"\n",
80-
"### Results"
132+
"model.instantiate(n_items=8, n_stores=2)"
81133
]
82134
},
83135
{
84136
"cell_type": "code",
85137
"execution_count": null,
86-
"id": "ba1b8457",
138+
"id": "2f8a915e",
87139
"metadata": {},
88140
"outputs": [],
89-
"source": []
141+
"source": [
142+
"history = model.fit(trip_dataset)"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"id": "1c78ef41",
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"plt.plot(history[\"train_loss\"])\n",
153+
"plt.xlabel(\"Epoch\")\n",
154+
"plt.ylabel(\"Training Loss\")\n",
155+
"plt.legend()\n",
156+
"plt.title(\"Training of AleaCarta\")\n",
157+
"plt.show()"
158+
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"id": "f337217b",
163+
"metadata": {},
164+
"source": [
165+
"## Results"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"id": "e4008d65",
172+
"metadata": {},
173+
"outputs": [],
174+
"source": [
175+
"import matplotlib.pyplot as plt\n",
176+
"import matplotlib as mpl\n",
177+
"import numpy as np\n",
178+
"\n",
179+
"fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))\n",
180+
"mask = np.ones((8,8), dtype=bool)\n",
181+
"res = []\n",
182+
"for i in range(8):\n",
183+
" r = model.compute_batch_utility(item_batch=np.array(list(range(8))),\n",
184+
" basket_batch=np.array([[i] for _ in range(8)]) ,\n",
185+
" store_batch=np.array([0, 0, 0, 0, 0, 0, 0, 0]),\n",
186+
" week_batch=np.array([0, 0, 0, 0, 0, 0, 0, 0]),\n",
187+
" price_batch=np.array([[0, 0, 0, 0, 0, 0] for _ in range(8)]))\n",
188+
" m = np.ones(8)\n",
189+
" m[i] = 0\n",
190+
" den = np.exp(r) * m\n",
191+
" r = den / den.sum()\n",
192+
" # r = np.concatenate([tf.nn.softmax(np.concatenate([r[:i], r[i+1:]]))[:i], [.0], tf.nn.softmax(np.concatenate([r[:i], r[i+1:]]))[i:]])\n",
193+
" res.append(r)\n",
194+
" mask[i][i] = False\n",
195+
"\n",
196+
"res = np.stack(res)\n",
197+
"mask = np.ma.masked_where(mask, res)\n",
198+
"\n",
199+
"axes.set_xticks([], [])\n",
200+
"axes.set_yticks([], [])\n",
201+
"im = axes.imshow(np.stack(res), cmap=\"Spectral\", alpha=0.99, vmin=0, vmax=1)\n",
202+
"axes.imshow(mask, cmap=mpl.colors.ListedColormap(['white']), alpha=1)\n",
203+
"\n",
204+
"cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.69])\n",
205+
"fig.colorbar(im, cax=cbar_ax)\n",
206+
"axes.set_title(\"Estimated Conditional Probabilities\")"
207+
]
90208
},
91209
{
92210
"cell_type": "markdown",

0 commit comments

Comments
 (0)