Skip to content

Commit 4d363c7

Browse files
Merge branch 'dev' into develop
2 parents 319ca7b + 25858c2 commit 4d363c7

File tree

9 files changed

+318
-2
lines changed

9 files changed

+318
-2
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,4 @@ cython_debug/
173173
# PyPI configuration file
174174
.pypirc
175175

176-
.DS_Store
176+
.DS_Store

doc/laurel_carta.png

1.04 MB
Loading

doc/resized_img.png

2.86 KB
Loading

doc/rs_img.png

89.7 KB
Loading

doc/rs_rs_img.png

1.39 KB
Loading

doc/rs_rs_rs_img.png

1.39 KB
Loading
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "60b9cc4f",
6+
"metadata": {},
7+
"source": [
8+
"# Synthetic Experiments"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "984f8af4",
14+
"metadata": {},
15+
"source": [
16+
"## Sample synthetic data"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"id": "d8f8d31b",
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"import sys\n",
27+
"sys.path.append(\"../\")\n",
28+
"import matplotlib.pyplot as plt\n",
29+
"import numpy as np\n",
30+
"\n",
31+
"import choice_learn\n",
32+
"from python.data import SyntheticDataGenerator\n",
33+
"from choice_learn.basket_models import Trip, TripDataset"
34+
]
35+
},
36+
{
37+
"cell_type": "markdown",
38+
"id": "b3024007",
39+
"metadata": {},
40+
"source": [
41+
"## Sample purchased baskets"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"id": "ed8a74e6",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"items_nests = {0:[0, 1, 2],\n",
52+
"1: [3, 4, 5],\n",
53+
"2: [6],\n",
54+
"3: [7]}\n",
55+
"\n",
56+
"nests_interactions = [[\"\", \"compl\", \"neutral\", \"neutral\"],\n",
57+
"[\"compl\", \"\", \"neutral\", \"neutral\"],\n",
58+
"[\"neutral\", \"neutral\", \"\", \"neutral\"],\n",
59+
"[\"neutral\", \"neutral\", \"neutral\", \"\"]]\n",
60+
"\n",
61+
"data_gen = SyntheticDataGenerator(items_nest=items_nests, nests_interactions=nests_interactions)"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"id": "9c3f06eb",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"dataset = data_gen.generate_dataset(n_baskets=1000)"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"id": "51791e7e",
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"trip_list = []\n",
82+
"for basket in dataset:\n",
83+
" trip_list.append(Trip(purchases=basket, prices=np.zeros((8, )), assortment=0))\n",
84+
"\n",
85+
"trip_dataset = TripDataset(trips=trip_list, available_items=np.ones((1, 8)))"
86+
]
87+
},
88+
{
89+
"cell_type": "markdown",
90+
"id": "52b4b18c",
91+
"metadata": {},
92+
"source": [
93+
"## Modelling "
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"id": "3d6c32e2",
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"from choice_learn.basket_models import AleaCarta"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"id": "6ef517b6",
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"latent_sizes = {\"preferences\": 6, \"price\": 3, \"season\": 3}\n",
114+
"n_negative_samples = 2\n",
115+
"optimizer = \"adam\"\n",
116+
"lr = 1e-2\n",
117+
"epochs = 200\n",
118+
"batch_size = 32\n",
119+
"\n",
120+
"model = AleaCarta(\n",
121+
" item_intercept=False,\n",
122+
" price_effects=False,\n",
123+
" seasonal_effects=False,\n",
124+
" latent_sizes=latent_sizes,\n",
125+
" n_negative_samples=n_negative_samples,\n",
126+
" optimizer=optimizer,\n",
127+
" lr=lr,\n",
128+
" epochs=epochs,\n",
129+
" batch_size=batch_size,\n",
130+
")\n",
131+
"\n",
132+
"model.instantiate(n_items=8, n_stores=2)"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": null,
138+
"id": "2f8a915e",
139+
"metadata": {},
140+
"outputs": [],
141+
"source": [
142+
"history = model.fit(trip_dataset)"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"id": "1c78ef41",
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"plt.plot(history[\"train_loss\"])\n",
153+
"plt.xlabel(\"Epoch\")\n",
154+
"plt.ylabel(\"Training Loss\")\n",
155+
"plt.legend()\n",
156+
"plt.title(\"Training of AleaCarta\")\n",
157+
"plt.show()"
158+
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"id": "f337217b",
163+
"metadata": {},
164+
"source": [
165+
"## Results"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"id": "e4008d65",
172+
"metadata": {},
173+
"outputs": [],
174+
"source": [
175+
"import matplotlib.pyplot as plt\n",
176+
"import matplotlib as mpl\n",
177+
"import numpy as np\n",
178+
"\n",
179+
"fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))\n",
180+
"mask = np.ones((8,8), dtype=bool)\n",
181+
"res = []\n",
182+
"for i in range(8):\n",
183+
" r = model.compute_batch_utility(item_batch=np.array(list(range(8))),\n",
184+
" basket_batch=np.array([[i] for _ in range(8)]) ,\n",
185+
" store_batch=np.array([0, 0, 0, 0, 0, 0, 0, 0]),\n",
186+
" week_batch=np.array([0, 0, 0, 0, 0, 0, 0, 0]),\n",
187+
" price_batch=np.array([[0, 0, 0, 0, 0, 0] for _ in range(8)]))\n",
188+
" m = np.ones(8)\n",
189+
" m[i] = 0\n",
190+
" den = np.exp(r) * m\n",
191+
" r = den / den.sum()\n",
192+
" # r = np.concatenate([tf.nn.softmax(np.concatenate([r[:i], r[i+1:]]))[:i], [.0], tf.nn.softmax(np.concatenate([r[:i], r[i+1:]]))[i:]])\n",
193+
" res.append(r)\n",
194+
" mask[i][i] = False\n",
195+
"\n",
196+
"res = np.stack(res)\n",
197+
"mask = np.ma.masked_where(mask, res)\n",
198+
"\n",
199+
"axes.set_xticks([], [])\n",
200+
"axes.set_yticks([], [])\n",
201+
"im = axes.imshow(np.stack(res), cmap=\"Spectral\", alpha=0.99, vmin=0, vmax=1)\n",
202+
"axes.imshow(mask, cmap=mpl.colors.ListedColormap(['white']), alpha=1)\n",
203+
"\n",
204+
"cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.69])\n",
205+
"fig.colorbar(im, cax=cbar_ax)\n",
206+
"axes.set_title(\"Estimated Conditional Probabilities\")"
207+
]
208+
},
209+
{
210+
"cell_type": "markdown",
211+
"id": "1089cdb5",
212+
"metadata": {},
213+
"source": []
214+
}
215+
],
216+
"metadata": {
217+
"kernelspec": {
218+
"display_name": "with_choice_learn",
219+
"language": "python",
220+
"name": "python3"
221+
},
222+
"language_info": {
223+
"codemirror_mode": {
224+
"name": "ipython",
225+
"version": 3
226+
},
227+
"file_extension": ".py",
228+
"mimetype": "text/x-python",
229+
"name": "python",
230+
"nbconvert_exporter": "python",
231+
"pygments_lexer": "ipython3",
232+
"version": "3.12.11"
233+
}
234+
},
235+
"nbformat": 4,
236+
"nbformat_minor": 5
237+
}

python/data.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Data generation related stuff."""
2+
3+
import numpy as np
4+
from tqdm import trange
5+
6+
7+
class SyntheticDataGenerator:
    """Generate synthetic purchase baskets from nests of items.

    Items are grouped into nests. A basket is built by drawing a first item
    uniformly from a uniformly drawn nest, then optionally adding one item
    from every other nest depending on how that nest relates to the first
    one ("compl" or "neutral"), and finally optionally adding one noise item.

    Parameters
    ----------
    items_nest : dict
        Maps a nest key to the list of item ids it contains. The keys are
        used to index ``nests_interactions``, so they should be the integers
        ``0 .. n_nests - 1``.
    nests_interactions : list
        Square table where ``nests_interactions[a][b]`` describes how nest
        ``b`` relates to nest ``a``: ``"compl"`` or ``"neutral"``. Any other
        value (e.g. ``""`` on the diagonal) never adds an item.
    proba_complementary_items : float
        Probability of adding an item from a ``"compl"`` nest.
    proba_neutral_items : float
        Probability of adding an item from a ``"neutral"`` nest.
    noise_proba : float
        Probability of adding one extra item not already in the basket.
    """

    def __init__(
        self,
        items_nest: dict,  # keys should be integer: the nest number
        nests_interactions: list,
        proba_complementary_items: float = 0.7,
        proba_neutral_items: float = 0.15,
        noise_proba: float = 0.05,
    ) -> None:
        self.proba_complementary_items = proba_complementary_items
        self.proba_neutral_items = proba_neutral_items
        self.noise_proba = noise_proba

        self.items_nest = items_nest
        self.nests_interactions = nests_interactions

    def generate_basket(self) -> list:
        """Generate one basket of items.

        Returns
        -------
        list
            Item ids of the basket, first drawn item first. Items come
            straight from ``np.random.choice`` and are therefore NumPy
            scalars.
        """

        def select_first_item() -> tuple:
            """Draw a nest uniformly, then an item uniformly inside it."""
            chosen_nest = np.random.choice(list(self.items_nest))
            chosen_item = np.random.choice(list(self.items_nest[chosen_nest]))
            return chosen_item, chosen_nest

        def complete_basket(first_item: int, first_nest: int) -> list:
            """Add at most one item per nest, based on its relation to ``first_nest``."""
            basket = [first_item]
            relations = self.nests_interactions[first_nest]
            for nest_id, items in self.items_nest.items():
                relation = relations[nest_id]
                # A random number is only drawn for "compl"/"neutral" nests,
                # which keeps the stream of random draws unchanged.
                if relation == "compl":
                    threshold = self.proba_complementary_items
                elif relation == "neutral":
                    threshold = self.proba_neutral_items
                else:
                    continue
                if np.random.random() < threshold:
                    basket.append(np.random.choice(items))
            return basket

        def add_noise(basket: list) -> list:
            """With probability ``noise_proba``, append one item not yet in the basket."""
            if np.random.random() < self.noise_proba:
                possible_noisy_items = [
                    item
                    for items in self.items_nest.values()
                    for item in items
                    if item not in basket
                ]
                if possible_noisy_items:
                    basket.append(np.random.choice(possible_noisy_items))
            return basket

        first_chosen_item, first_chosen_nest = select_first_item()
        basket = complete_basket(first_item=first_chosen_item, first_nest=first_chosen_nest)
        return add_noise(basket)

    def generate_dataset(self, n_baskets) -> np.ndarray:
        """Generate a dataset of baskets.

        Parameters
        ----------
        n_baskets : int
            Number of baskets to generate.

        Returns
        -------
        np.ndarray
            1-D object array of shape ``(n_baskets,)`` whose elements are the
            basket lists.
        """
        baskets = [self.generate_basket() for _ in range(n_baskets)]
        # np.array(baskets, dtype=object) would silently build a 2-D array
        # whenever all baskets share the same length; filling a pre-allocated
        # 1-D object array keeps one list per element in every case.
        dataset = np.empty(len(baskets), dtype=object)
        for position, basket in enumerate(baskets):
            dataset[position] = basket
        return dataset
78+
79+

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
choice-learn
2-
matplotlib
2+
matplotlib

0 commit comments

Comments
 (0)