Skip to content
388 changes: 388 additions & 0 deletions identify_linear_expression.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,388 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BK5As0cbUejz",
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"import random\n",
"\n",
"from etuples import etuple\n",
"from unification import unify, var\n",
"\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph import rewrite_graph\n",
"from pytensor.graph.fg import FunctionGraph\n",
"from pytensor.graph.rewriting.basic import MergeOptimizer, PatternNodeRewriter, out2in"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.438328768Z",
"start_time": "2025-08-14T11:29:54.500174Z"
},
"id": "alNycwOIUzTM"
},
"outputs": [],
"source": [
"def find_optimal_P(P, Q, mc):\n",
" pi = (Q * (P - mc)).sum()\n",
" dpi_dP = pt.grad(pi, P)\n",
" # P_star, success = root(dpi_dP, P, method=\"hybr\", optimizer_kwargs=dict(tol=1e-8))\n",
" # return P_star, success\n",
" return dpi_dP"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.440094174Z",
"start_time": "2025-08-14T11:31:54.469010Z"
},
"id": "wVnYGz8GVKb4"
},
"outputs": [],
"source": [
"price_effect = pt.scalar(\"price_effect\")\n",
"price = pt.vector(\"price\")\n",
"trend = pt.vector(\"trend\")\n",
"seasonality = pt.vector(\"seasonality\")\n",
"mc = pt.scalar(\"marginal_cost\")\n",
"\n",
"price_term = price * price_effect\n",
"expected_sales = trend + price_term + seasonality"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.440827348Z",
"start_time": "2025-08-14T11:31:54.681476Z"
},
"id": "BeitshYMVkQU"
},
"outputs": [],
"source": [
"expr = find_optimal_P(price, expected_sales, mc=mc)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.443902007Z",
"start_time": "2025-08-14T11:31:54.918556Z"
},
"id": "jugOxL4DcRFN"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A] 5\n",
" ├─ Mul [id B] 4\n",
" │ ├─ Sub [id C] 3\n",
" │ │ ├─ price [id D]\n",
" │ │ └─ ExpandDims{axis=0} [id E] 2\n",
" │ │ └─ marginal_cost [id F]\n",
" │ └─ ExpandDims{axis=0} [id G] 0\n",
" │ └─ price_effect [id H]\n",
" ├─ trend [id I]\n",
" ├─ Mul [id J] 1\n",
" │ ├─ price [id D]\n",
" │ └─ ExpandDims{axis=0} [id G] 0\n",
" │ └─ ···\n",
" └─ seasonality [id K]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use existing rewrites to simplify expression\n",
"fgraph = FunctionGraph(outputs=[expr], clone=False)\n",
"rewrite_graph(fgraph, include=(\"canonicalize\",))\n",
"fgraph.dprint()"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.445406846Z",
"start_time": "2025-08-14T11:31:55.243098Z"
},
"id": "86-KeCOFWQZU"
},
"outputs": [],
"source": [
"# distribute_mul_over_add = PatternNodeRewriter(\n",
"# (pt.mul, (pt.add, \"x\", \"y\"), \"z\"),\n",
"# (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, \"y\", \"z\")),\n",
"# )\n",
"\n",
"distribute_mul_over_sub = PatternNodeRewriter(\n",
" (pt.mul, (pt.sub, \"x\", \"y\"), \"z\"),\n",
" (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, (pt.neg, \"y\"), \"z\")),\n",
")\n",
"\n",
"combine_addition_terms = PatternNodeRewriter(\n",
" (pt.add, (pt.add, \"x\", \"y\"), \"z\", \"x\", \"w\"),\n",
" (pt.add, (pt.mul, \"x\", 2), (pt.add, \"y\", \"z\", \"w\")),\n",
")\n",
"\n",
"# distribute_mul_over_add = out2in(distribute_mul_over_add, name=\"distribute_mul_add\")\n",
"distribute_mul_over_sub = out2in(distribute_mul_over_sub, name=\"distribute_mul_sub\")\n",
"combine_addition_terms = out2in(combine_addition_terms, name=\"combine_addition_terms\")\n",
"\n",
"# distribute\n",
"distribute_mul_over_sub.rewrite(fgraph)\n",
"# merge equivalent terms\n",
"MergeOptimizer().rewrite(fgraph)\n",
"# combine equivalent terms\n",
"combine_addition_terms.rewrite(fgraph)\n",
"# extract rewritten expression\n",
"expr = fgraph.outputs[0]"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.446341276Z",
"start_time": "2025-08-14T11:31:56.276558Z"
},
"id": "4qGBap72Xvvn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A]\n",
" ├─ Mul [id B]\n",
" │ ├─ Mul [id C]\n",
" │ │ ├─ price [id D]\n",
" │ │ └─ ExpandDims{axis=0} [id E]\n",
" │ │ └─ price_effect [id F]\n",
" │ └─ ExpandDims{axis=0} [id G]\n",
" │ └─ 2 [id H]\n",
" └─ Add [id I]\n",
" ├─ Mul [id J]\n",
" │ ├─ Neg [id K]\n",
" │ │ └─ ExpandDims{axis=0} [id L]\n",
" │ │ └─ marginal_cost [id M]\n",
" │ └─ ExpandDims{axis=0} [id E]\n",
" │ └─ ···\n",
" ├─ trend [id N]\n",
" └─ seasonality [id O]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expr.dprint()"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.447033733Z",
"start_time": "2025-08-14T11:31:59.481064Z"
},
"id": "8Fq10k2LcCY-"
},
"outputs": [],
"source": [
"# Create variations of a graph for pattern matching\n",
"rewrites = [\n",
" out2in(\n",
" PatternNodeRewriter((pt.add, \"x\", \"y\"), (pt.add, \"y\", \"x\")),\n",
" name=\"commutative_add\",\n",
" ignore_newtrees=True,\n",
" ),\n",
" out2in(\n",
" PatternNodeRewriter((pt.mul, \"x\", \"y\"), (pt.mul, \"y\", \"x\")),\n",
" name=\"commutative_mul\",\n",
" ignore_newtrees=True,\n",
" ),\n",
" out2in(\n",
" PatternNodeRewriter(\n",
" (pt.mul, (pt.mul, \"x\", \"y\"), \"z\"), (pt.mul, \"x\", (pt.mul, \"y\", \"z\"))\n",
" ),\n",
" name=\"associative_mul\",\n",
" ignore_newtrees=True,\n",
" ),\n",
"]\n",
"\n",
"\n",
"def yield_arithmetic_variants(expr, n):\n",
" fgraph = FunctionGraph(outputs=[expr], clone=False)\n",
" while n > 0:\n",
" rewrite = random.choice(rewrites)\n",
" res = rewrite.apply(fgraph)\n",
" n -= 1\n",
" if res:\n",
" yield fgraph.outputs[0]\n",
" yield fgraph.outputs[0]"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.447578804Z",
"start_time": "2025-08-14T11:31:59.831774Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 198
},
"id": "h9K70LGxYJ7E",
"outputId": "793e98c6-4570-43bf-a452-eb6d0d745dc7"
},
"outputs": [
{
"data": {
"text/plain": [
"{~price: price, ~a: Mul.0, ~b: Add.0}"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rewrite graph randomly until we match price * a + b\n",
"a, b, price_ = var(\"a\"), var(\"b\"), var(\"price\")\n",
"pattern = etuple(pt.add, etuple(pt.mul, price_, a), b)\n",
"\n",
"for variant in yield_arithmetic_variants(expr, n=100):\n",
" match_dict = unify(variant, pattern)\n",
" if match_dict and match_dict[price_] is price:\n",
" break\n",
"else:\n",
" raise ValueError(\"No matching variant found\")\n",
"match_dict"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.448905279Z",
"start_time": "2025-08-14T11:32:01.264784Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8M-qjXBKa6Db",
"outputId": "cdce40c4-e1dd-4757-f4d6-f368643bb5c1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True_div [id A]\n",
" ├─ Neg [id B]\n",
" │ └─ Add [id C]\n",
" │ ├─ Mul [id D]\n",
" │ │ ├─ Neg [id E]\n",
" │ │ │ └─ ExpandDims{axis=0} [id F]\n",
" │ │ │ └─ marginal_cost [id G]\n",
" │ │ └─ ExpandDims{axis=0} [id H]\n",
" │ │ └─ price_effect [id I]\n",
" │ ├─ trend [id J]\n",
" │ └─ seasonality [id K]\n",
" └─ Mul [id L]\n",
" ├─ ExpandDims{axis=0} [id H]\n",
" │ └─ ···\n",
" └─ ExpandDims{axis=0} [id M]\n",
" └─ 2 [id N]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7fcbccd613c0>"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimal_result = -match_dict[b] / match_dict[a]\n",
"optimal_result.dprint()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2025-08-14T11:32:09.449645675Z",
"start_time": "2025-08-14T11:25:52.269957Z"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
2 changes: 1 addition & 1 deletion pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ def c_code_cache_version(self):
return v


softplus = Softplus(upgrade_to_float, name="scalar_softplus")
softplus = Softplus(upgrade_to_float, name="softplus")


class Log1mexp(UnaryScalarOp):
Expand Down
Loading
Loading