Skip to content

Commit c843d2f

Browse files
authored
SymPy > 1.12.1 yield wrong results in PSATD symbolic notebooks (#5968)
I recently reused the SymPy notebooks that were added in #3456 and #4316, and I noticed that versions of SymPy higher than 1.12.1 exhibit buggy calculations and yield wrong results (e.g., wrong coefficients for the equations in #4316). This PR implements the following changes: - Force `sympy<=1.12.1` in the notebooks, raise exception otherwise. - Raise exceptions whenever the symbolic solutions cannot be verified.
1 parent 89a8034 commit c843d2f

File tree

3 files changed

+50
-33
lines changed

3 files changed

+50
-33
lines changed

.github/workflows/jupyter.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
run: |
3232
python -m pip install --upgrade pip
3333
pip install nbconvert nbformat jupyter
34-
pip install sympy
34+
pip install "sympy<=1.12.1" # higher versions yield wrong calculations
3535
- name: Run Notebooks
3636
run: |
3737
cd Tools/Algorithms

Tools/Algorithms/psatd.ipynb

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,17 @@
77
"outputs": [],
88
"source": [
99
"import sympy as sp\n",
10+
"from packaging.version import parse\n",
1011
"from sympy import * # noqa\n",
1112
"\n",
1213
"sp.init_session()\n",
13-
"sp.init_printing()"
14+
"sp.init_printing()\n",
15+
"\n",
16+
"sp_version_compatible = \"1.12.1\"\n",
17+
"if parse(sp.__version__) > parse(sp_version_compatible):\n",
18+
" raise ValueError(\n",
19+
" f\"Versions of sympy>{sp_version_compatible} do not yield correct results, please use sympy<={sp_version_compatible}\"\n",
20+
" )"
1421
]
1522
},
1623
{
@@ -38,7 +45,7 @@
3845
"divE_cleaning = True\n",
3946
"divB_cleaning = True\n",
4047
"J_in_time = \"constant\"\n",
41-
"rho_in_time = \"constant\""
48+
"rho_in_time = \"linear\""
4249
]
4350
},
4451
{
@@ -65,8 +72,9 @@
6572
" diff = W[i, j] - Wd[i, j]\n",
6673
" diff = diff.expand().simplify()\n",
6774
" if diff != 0:\n",
68-
" print(rf\"Could Not Verify Diagonalization for Component ({i},{j}):\")\n",
69-
" display(diff)\n",
75+
" raise ValueError(\n",
76+
" f\"Could not verify diagonalization for component ({i},{j})\"\n",
77+
" )\n",
7078
"\n",
7179
"\n",
7280
"def simple_mat(W):\n",
@@ -557,8 +565,7 @@
557565
" diff = lhs - rhs\n",
558566
" diff = diff.simplify()\n",
559567
" if diff != 0:\n",
560-
" print(r\"Could Not Verify Time Integration\")\n",
561-
" display(diff)"
568+
" raise ValueError(\"Could not verify time integration\")"
562569
]
563570
},
564571
{
@@ -601,7 +608,7 @@
601608
" .trigsimp()\n",
602609
" .simplify()\n",
603610
" )\n",
604-
" print(f\"Coefficient of {L[i]} with respect to {R[j]}:\")\n",
611+
" print(rf\"Coefficient of {L[i]} with respect to {R[j]}:\")\n",
605612
" display(coeff_h[key])"
606613
]
607614
},
@@ -681,7 +688,7 @@
681688
" .trigsimp()\n",
682689
" .simplify()\n",
683690
" )\n",
684-
" print(f\"Coefficient of {L[i]} Multiplying {R[j]}:\")\n",
691+
" print(rf\"Coefficient of {L[i]} Multiplying {R[j]}:\")\n",
685692
" display(coeff_nh[key])"
686693
]
687694
},
@@ -711,7 +718,7 @@
711718
"name": "python",
712719
"nbconvert_exporter": "python",
713720
"pygments_lexer": "ipython3",
714-
"version": "3.13.1"
721+
"version": "3.13.3"
715722
}
716723
},
717724
"nbformat": 4,

Tools/Algorithms/psatd_pml.ipynb

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@
99
"outputs": [],
1010
"source": [
1111
"import sympy as sp\n",
12+
"from packaging.version import parse\n",
1213
"from sympy import * # noqa\n",
1314
"from sympy.solvers.solveset import linsolve\n",
1415
"\n",
1516
"sp.init_session()\n",
16-
"sp.init_printing()"
17+
"sp.init_printing()\n",
18+
"\n",
19+
"sp_version_compatible = \"1.12.1\"\n",
20+
"if parse(sp.__version__) > parse(sp_version_compatible):\n",
21+
" raise ValueError(\n",
22+
" f\"Versions of sympy>{sp_version_compatible} do not yield correct results, please use sympy<={sp_version_compatible}\"\n",
23+
" )"
1724
]
1825
},
1926
{
@@ -151,7 +158,7 @@
151158
" MX_charpoly = MX_charmat.det()\n",
152159
" MX_charpoly = factor(MX_charpoly.as_expr())\n",
153160
"\n",
154-
" print(r\"Characteristic Polynomial:\")\n",
161+
" print(r\"Characteristic polynomial:\")\n",
155162
" display(MX_charpoly)\n",
156163
"\n",
157164
" MX_eigenvals = sp.solve(MX_charpoly, lamda)\n",
@@ -171,12 +178,12 @@
171178
" print(r\"Eigenvalue:\")\n",
172179
" display(ev)\n",
173180
"\n",
174-
" print(r\"Characteristic Matrix:\")\n",
181+
" print(r\"Characteristic matrix:\")\n",
175182
" display(A)\n",
176183
"\n",
177184
" # Perform Gaussian elimination (necessary for lamda != 0)\n",
178185
" if ev != 0.0:\n",
179-
" print(r\"Gaussian Elimination:\")\n",
186+
" print(r\"Gaussian elimination:\")\n",
180187
" print(r\"A[0,:] += A[1,:]\")\n",
181188
" A[0, :] += A[1, :]\n",
182189
" print(r\"A[0,:] += A[2,:]\")\n",
@@ -235,8 +242,9 @@
235242
" diff = MX * ep[2][j] - ep[0] * ep[2][j]\n",
236243
" diff.simplify()\n",
237244
" if diff != zeros(DD, 1):\n",
238-
" print(\"Could Not Verify Characteristic Equation for Some Eigenpairs\")\n",
239-
" display(diff)\n",
245+
" raise ValueError(\n",
246+
" \"Could not verify characteristic equation for some eigenpairs\"\n",
247+
" )\n",
240248
"\n",
241249
" # Define integration constants\n",
242250
" a = []\n",
@@ -276,8 +284,7 @@
276284
" diff = X_t.diff(t).diff(t).subs(t, tn).subs(om, c * knorm).expand() - d2X_dt2\n",
277285
" diff.simplify()\n",
278286
" if diff != zeros(DD, 1):\n",
279-
" print(rf\"Could Not Verify Time Integration for {str(X)}\")\n",
280-
" display(diff)\n",
287+
" raise ValueError(rf\"Could not verify time integration for {str(X)}\")\n",
281288
"\n",
282289
" return X_t, X_new"
283290
]
@@ -565,10 +572,10 @@
565572
"metadata": {},
566573
"outputs": [],
567574
"source": [
568-
"print(r\"Solve Equations for E and G:\")\n",
575+
"print(r\"Solving equations for E and G...\")\n",
569576
"EG_t, EG_new = evolve(EG, dEG_dt, d2EG_dt2)\n",
570577
"\n",
571-
"print(r\"Solve Equations for B and F:\")\n",
578+
"print(r\"Solving equations for B and F...\")\n",
572579
"BF_t, BF_new = evolve(BF, dBF_dt, d2BF_dt2)\n",
573580
"\n",
574581
"# Check correctness by taking *first* derivative\n",
@@ -577,14 +584,12 @@
577584
"diff = EG_t.diff(t).subs(t, tn).subs(om, c * knorm).expand() - dEG_dt\n",
578585
"diff.simplify()\n",
579586
"if diff != zeros(DD, 1):\n",
580-
" print(rf\"Could Not Verify Time Integration for {str(EG)}\")\n",
581-
" display(diff)\n",
587+
" raise ValueError(rf\"Could not verify time integration for {str(EG)}\")\n",
582588
"# B,F\n",
583589
"diff = BF_t.diff(t).subs(t, tn).subs(om, c * knorm).expand() - dBF_dt\n",
584590
"diff.simplify()\n",
585591
"if diff != zeros(DD, 1):\n",
586-
" print(rf\"Could Not Verify Time Integration for {str(BF)}\")\n",
587-
" display(diff)"
592+
" raise ValueError(rf\"Could not verify time integration for {str(BF)}\")"
588593
]
589594
},
590595
{
@@ -600,25 +605,30 @@
600605
"metadata": {},
601606
"outputs": [],
602607
"source": [
603-
"# Code generation\n",
604-
"\n",
605608
"# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11\n",
606609
"# EG: Exx, Exy, Exz, Eyx, Eyy, Eyz, Ezx, Ezy, Ezz, Gx, Gy, Gz\n",
607610
"# BF: Bxx, Bxy, Bxz, Byx, Byy, Byz, Bzx, Bzy, Bzz, Fx, Fy, Fz\n",
608611
"\n",
609612
"# Select update equation (left hand side)\n",
610613
"X_new = sp.Matrix.vstack(EG_new, BF_new)\n",
614+
"X_old = sp.Matrix.vstack(EG, BF)\n",
611615
"for i in range(X_new.shape[0]):\n",
612616
" field_lhs = X_new[i, 0]\n",
617+
" field_lhs_label = X_old[i, 0]\n",
613618
" # Extract individual terms (right hand side)\n",
614-
" X = sp.Matrix.vstack(EG, BF)\n",
615-
" for j in range(X.shape[0]):\n",
616-
" field_rhs = X[j, 0]\n",
619+
" for j in range(X_old.shape[0]):\n",
620+
" field_rhs = X_old[j, 0]\n",
617621
" coeff = field_lhs.coeff(field_rhs, 1).simplify()\n",
618-
" print(rf\"Coefficient of {str(field_rhs)} Multiplying {str(field_rhs)}\")\n",
619-
" display(coeff)\n",
620-
" # print(ccode(Assignment(sp.symbols(r'LHS'), C1)))"
622+
" print(rf\"Coefficient of {str(field_lhs_label)} multiplying {str(field_rhs)}\")\n",
623+
" display(coeff)"
621624
]
625+
},
626+
{
627+
"cell_type": "code",
628+
"execution_count": null,
629+
"metadata": {},
630+
"outputs": [],
631+
"source": []
622632
}
623633
],
624634
"metadata": {
@@ -637,7 +647,7 @@
637647
"name": "python",
638648
"nbconvert_exporter": "python",
639649
"pygments_lexer": "ipython3",
640-
"version": "3.13.1"
650+
"version": "3.13.3"
641651
}
642652
},
643653
"nbformat": 4,

0 commit comments

Comments
 (0)