Skip to content

Commit 07fcc1e

Browse files
progress
1 parent 2730892 commit 07fcc1e

File tree

2 files changed

+734
-79
lines changed

2 files changed

+734
-79
lines changed

identify_linear_expression.ipynb

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"id": "BK5As0cbUejz",
8+
"jupyter": {
9+
"is_executing": true
10+
}
11+
},
12+
"outputs": [],
13+
"source": [
14+
"import random\n",
15+
"\n",
16+
"from etuples import etuple\n",
17+
"from unification import unify, var\n",
18+
"\n",
19+
"import pytensor.tensor as pt\n",
20+
"from pytensor.graph import rewrite_graph\n",
21+
"from pytensor.graph.fg import FunctionGraph\n",
22+
"from pytensor.graph.rewriting.basic import MergeOptimizer, PatternNodeRewriter, out2in"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 87,
28+
"metadata": {
29+
"ExecuteTime": {
30+
"end_time": "2025-08-14T11:32:09.438328768Z",
31+
"start_time": "2025-08-14T11:29:54.500174Z"
32+
},
33+
"id": "alNycwOIUzTM"
34+
},
35+
"outputs": [],
36+
"source": [
37+
"def find_optimal_P(P, Q, mc):\n",
38+
" pi = (Q * (P - mc)).sum()\n",
39+
" dpi_dP = pt.grad(pi, P)\n",
40+
" # P_star, success = root(dpi_dP, P, method=\"hybr\", optimizer_kwargs=dict(tol=1e-8))\n",
41+
" # return P_star, success\n",
42+
" return dpi_dP"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": 97,
48+
"metadata": {
49+
"ExecuteTime": {
50+
"end_time": "2025-08-14T11:32:09.440094174Z",
51+
"start_time": "2025-08-14T11:31:54.469010Z"
52+
},
53+
"id": "wVnYGz8GVKb4"
54+
},
55+
"outputs": [],
56+
"source": [
57+
"price_effect = pt.scalar(\"price_effect\")\n",
58+
"price = pt.vector(\"price\")\n",
59+
"trend = pt.vector(\"trend\")\n",
60+
"seasonality = pt.vector(\"seasonality\")\n",
61+
"mc = pt.scalar(\"marginal_cost\")\n",
62+
"\n",
63+
"price_term = price * price_effect\n",
64+
"expected_sales = trend + price_term + seasonality"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": 98,
70+
"metadata": {
71+
"ExecuteTime": {
72+
"end_time": "2025-08-14T11:32:09.440827348Z",
73+
"start_time": "2025-08-14T11:31:54.681476Z"
74+
},
75+
"id": "BeitshYMVkQU"
76+
},
77+
"outputs": [],
78+
"source": [
79+
"expr = find_optimal_P(price, expected_sales, mc=mc)"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": 99,
85+
"metadata": {
86+
"ExecuteTime": {
87+
"end_time": "2025-08-14T11:32:09.443902007Z",
88+
"start_time": "2025-08-14T11:31:54.918556Z"
89+
},
90+
"id": "jugOxL4DcRFN"
91+
},
92+
"outputs": [
93+
{
94+
"name": "stdout",
95+
"output_type": "stream",
96+
"text": [
97+
"Add [id A] 5\n",
98+
" ├─ Mul [id B] 4\n",
99+
" │ ├─ Sub [id C] 3\n",
100+
" │ │ ├─ price [id D]\n",
101+
" │ │ └─ ExpandDims{axis=0} [id E] 2\n",
102+
" │ │ └─ marginal_cost [id F]\n",
103+
" │ └─ ExpandDims{axis=0} [id G] 0\n",
104+
" │ └─ price_effect [id H]\n",
105+
" ├─ trend [id I]\n",
106+
" ├─ Mul [id J] 1\n",
107+
" │ ├─ price [id D]\n",
108+
" │ └─ ExpandDims{axis=0} [id G] 0\n",
109+
" │ └─ ···\n",
110+
" └─ seasonality [id K]\n"
111+
]
112+
},
113+
{
114+
"data": {
115+
"text/plain": [
116+
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
117+
]
118+
},
119+
"execution_count": 99,
120+
"metadata": {},
121+
"output_type": "execute_result"
122+
}
123+
],
124+
"source": [
125+
"# Use existing rewrites to simplify expression\n",
126+
"fgraph = FunctionGraph(outputs=[expr], clone=False)\n",
127+
"rewrite_graph(fgraph, include=(\"canonicalize\",))\n",
128+
"fgraph.dprint()"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 100,
134+
"metadata": {
135+
"ExecuteTime": {
136+
"end_time": "2025-08-14T11:32:09.445406846Z",
137+
"start_time": "2025-08-14T11:31:55.243098Z"
138+
},
139+
"id": "86-KeCOFWQZU"
140+
},
141+
"outputs": [],
142+
"source": [
143+
"# distribute_mul_over_add = PatternNodeRewriter(\n",
144+
"# (pt.mul, (pt.add, \"x\", \"y\"), \"z\"),\n",
145+
"# (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, \"y\", \"z\")),\n",
146+
"# )\n",
147+
"\n",
148+
"distribute_mul_over_sub = PatternNodeRewriter(\n",
149+
" (pt.mul, (pt.sub, \"x\", \"y\"), \"z\"),\n",
150+
" (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, (pt.neg, \"y\"), \"z\")),\n",
151+
")\n",
152+
"\n",
153+
"combine_addition_terms = PatternNodeRewriter(\n",
154+
" (pt.add, (pt.add, \"x\", \"y\"), \"z\", \"x\", \"w\"),\n",
155+
" (pt.add, (pt.mul, \"x\", 2), (pt.add, \"y\", \"z\", \"w\")),\n",
156+
")\n",
157+
"\n",
158+
"# distribute_mul_over_add = out2in(distribute_mul_over_add, name=\"distribute_mul_add\")\n",
159+
"distribute_mul_over_sub = out2in(distribute_mul_over_sub, name=\"distribute_mul_sub\")\n",
160+
"combine_addition_terms = out2in(combine_addition_terms, name=\"combine_addition_terms\")\n",
161+
"\n",
162+
"# distribute\n",
163+
"distribute_mul_over_sub.rewrite(fgraph)\n",
164+
"# merge equivalent terms\n",
165+
"MergeOptimizer().rewrite(fgraph)\n",
166+
"# combine equivalent terms\n",
167+
"combine_addition_terms.rewrite(fgraph)\n",
168+
"# extract rewritten expression\n",
169+
"expr = fgraph.outputs[0]"
170+
]
171+
},
172+
{
173+
"cell_type": "code",
174+
"execution_count": 101,
175+
"metadata": {
176+
"ExecuteTime": {
177+
"end_time": "2025-08-14T11:32:09.446341276Z",
178+
"start_time": "2025-08-14T11:31:56.276558Z"
179+
},
180+
"id": "4qGBap72Xvvn"
181+
},
182+
"outputs": [
183+
{
184+
"name": "stdout",
185+
"output_type": "stream",
186+
"text": [
187+
"Add [id A]\n",
188+
" ├─ Mul [id B]\n",
189+
" │ ├─ Mul [id C]\n",
190+
" │ │ ├─ price [id D]\n",
191+
" │ │ └─ ExpandDims{axis=0} [id E]\n",
192+
" │ │ └─ price_effect [id F]\n",
193+
" │ └─ ExpandDims{axis=0} [id G]\n",
194+
" │ └─ 2 [id H]\n",
195+
" └─ Add [id I]\n",
196+
" ├─ Mul [id J]\n",
197+
" │ ├─ Neg [id K]\n",
198+
" │ │ └─ ExpandDims{axis=0} [id L]\n",
199+
" │ │ └─ marginal_cost [id M]\n",
200+
" │ └─ ExpandDims{axis=0} [id E]\n",
201+
" │ └─ ···\n",
202+
" ├─ trend [id N]\n",
203+
" └─ seasonality [id O]\n"
204+
]
205+
},
206+
{
207+
"data": {
208+
"text/plain": [
209+
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
210+
]
211+
},
212+
"execution_count": 101,
213+
"metadata": {},
214+
"output_type": "execute_result"
215+
}
216+
],
217+
"source": [
218+
"expr.dprint()"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": 102,
224+
"metadata": {
225+
"ExecuteTime": {
226+
"end_time": "2025-08-14T11:32:09.447033733Z",
227+
"start_time": "2025-08-14T11:31:59.481064Z"
228+
},
229+
"id": "8Fq10k2LcCY-"
230+
},
231+
"outputs": [],
232+
"source": [
233+
"# Create variations of a graph for pattern matching\n",
234+
"rewrites = [\n",
235+
" out2in(\n",
236+
" PatternNodeRewriter((pt.add, \"x\", \"y\"), (pt.add, \"y\", \"x\")),\n",
237+
" name=\"commutative_add\",\n",
238+
" ignore_newtrees=True,\n",
239+
" ),\n",
240+
" out2in(\n",
241+
" PatternNodeRewriter((pt.mul, \"x\", \"y\"), (pt.mul, \"y\", \"x\")),\n",
242+
" name=\"commutative_mul\",\n",
243+
" ignore_newtrees=True,\n",
244+
" ),\n",
245+
" out2in(\n",
246+
" PatternNodeRewriter(\n",
247+
" (pt.mul, (pt.mul, \"x\", \"y\"), \"z\"), (pt.mul, \"x\", (pt.mul, \"y\", \"z\"))\n",
248+
" ),\n",
249+
" name=\"associative_mul\",\n",
250+
" ignore_newtrees=True,\n",
251+
" ),\n",
252+
"]\n",
253+
"\n",
254+
"\n",
255+
"def yield_arithmetic_variants(expr, n):\n",
256+
" fgraph = FunctionGraph(outputs=[expr], clone=False)\n",
257+
" while n > 0:\n",
258+
" rewrite = random.choice(rewrites)\n",
259+
" res = rewrite.apply(fgraph)\n",
260+
" n -= 1\n",
261+
" if res:\n",
262+
" yield fgraph.outputs[0]\n",
263+
" yield fgraph.outputs[0]"
264+
]
265+
},
266+
{
267+
"cell_type": "code",
268+
"execution_count": 103,
269+
"metadata": {
270+
"ExecuteTime": {
271+
"end_time": "2025-08-14T11:32:09.447578804Z",
272+
"start_time": "2025-08-14T11:31:59.831774Z"
273+
},
274+
"colab": {
275+
"base_uri": "https://localhost:8080/",
276+
"height": 198
277+
},
278+
"id": "h9K70LGxYJ7E",
279+
"outputId": "793e98c6-4570-43bf-a452-eb6d0d745dc7"
280+
},
281+
"outputs": [
282+
{
283+
"data": {
284+
"text/plain": [
285+
"{~price: price, ~a: Mul.0, ~b: Add.0}"
286+
]
287+
},
288+
"execution_count": 103,
289+
"metadata": {},
290+
"output_type": "execute_result"
291+
}
292+
],
293+
"source": [
294+
"# Rewrite graph randomly until we match price * a + b\n",
295+
"a, b, price_ = var(\"a\"), var(\"b\"), var(\"price\")\n",
296+
"pattern = etuple(pt.add, etuple(pt.mul, price_, a), b)\n",
297+
"\n",
298+
"for variant in yield_arithmetic_variants(expr, n=100):\n",
299+
" match_dict = unify(variant, pattern)\n",
300+
" if match_dict and match_dict[price_] is price:\n",
301+
" break\n",
302+
"else:\n",
303+
" raise ValueError(\"No matching variant found\")\n",
304+
"match_dict"
305+
]
306+
},
307+
{
308+
"cell_type": "code",
309+
"execution_count": 104,
310+
"metadata": {
311+
"ExecuteTime": {
312+
"end_time": "2025-08-14T11:32:09.448905279Z",
313+
"start_time": "2025-08-14T11:32:01.264784Z"
314+
},
315+
"colab": {
316+
"base_uri": "https://localhost:8080/"
317+
},
318+
"id": "8M-qjXBKa6Db",
319+
"outputId": "cdce40c4-e1dd-4757-f4d6-f368643bb5c1"
320+
},
321+
"outputs": [
322+
{
323+
"name": "stdout",
324+
"output_type": "stream",
325+
"text": [
326+
"True_div [id A]\n",
327+
" ├─ Neg [id B]\n",
328+
" │ └─ Add [id C]\n",
329+
" │ ├─ Mul [id D]\n",
330+
" │ │ ├─ Neg [id E]\n",
331+
" │ │ │ └─ ExpandDims{axis=0} [id F]\n",
332+
" │ │ │ └─ marginal_cost [id G]\n",
333+
" │ │ └─ ExpandDims{axis=0} [id H]\n",
334+
" │ │ └─ price_effect [id I]\n",
335+
" │ ├─ trend [id J]\n",
336+
" │ └─ seasonality [id K]\n",
337+
" └─ Mul [id L]\n",
338+
" ├─ ExpandDims{axis=0} [id H]\n",
339+
" │ └─ ···\n",
340+
" └─ ExpandDims{axis=0} [id M]\n",
341+
" └─ 2 [id N]\n"
342+
]
343+
},
344+
{
345+
"data": {
346+
"text/plain": [
347+
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
348+
]
349+
},
350+
"execution_count": 104,
351+
"metadata": {},
352+
"output_type": "execute_result"
353+
}
354+
],
355+
"source": [
356+
"optimal_result = -match_dict[b] / match_dict[a]\n",
357+
"optimal_result.dprint()"
358+
]
359+
},
360+
{
361+
"cell_type": "code",
362+
"execution_count": null,
363+
"metadata": {
364+
"ExecuteTime": {
365+
"end_time": "2025-08-14T11:32:09.449645675Z",
366+
"start_time": "2025-08-14T11:25:52.269957Z"
367+
}
368+
},
369+
"outputs": [],
370+
"source": []
371+
}
372+
],
373+
"metadata": {
374+
"colab": {
375+
"provenance": []
376+
},
377+
"kernelspec": {
378+
"display_name": "Python 3 (ipykernel)",
379+
"language": "python",
380+
"name": "python3"
381+
},
382+
"language_info": {
383+
"name": "python"
384+
}
385+
},
386+
"nbformat": 4,
387+
"nbformat_minor": 0
388+
}

0 commit comments

Comments
 (0)