Skip to content

Commit e868553

Browse files
committed
clean up stretch factors
1 parent 7aed028 commit e868553

File tree

1 file changed

+55
-52
lines changed

1 file changed

+55
-52
lines changed

test/test_symbolic.py

Lines changed: 55 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -500,28 +500,45 @@ def make_gmsh_torus(order: int, cls: type):
500500
)
501501

502502

503-
def make_simplex_stretch_factors(ambient_dim):
503+
def metric_from_form1(form1, metric_type: str):
504+
from pytential.symbolic.primitives import _small_mat_eigenvalues
505+
s0, s1 = _small_mat_eigenvalues(4 * form1)
506+
507+
if metric_type == "singvals":
508+
return np.array([sym.sqrt(s0), sym.sqrt(s1)], dtype=object)
509+
elif metric_type == "det":
510+
return np.array([s0 * s1], dtype=object)
511+
elif metric_type == "trace":
512+
return np.array([s0 + s1], dtype=object)
513+
elif metric_type == "norm":
514+
return np.array([sym.sqrt(s0**2 + s1**2)], dtype=object)
515+
elif metric_type == "aspect":
516+
return np.array([
517+
(s0 * s1)**(2/3) / (s0**2 + s1**2),
518+
], dtype=object)
519+
elif metric_type == "condition":
520+
import pymbolic.primitives as prim
521+
return np.array([
522+
prim.Max((s0, s1)) / prim.Min((s0, s1))
523+
], dtype=object)
524+
else:
525+
raise ValueError(f"unknown metric type: '{metric_type}'")
526+
527+
528+
def make_simplex_stretch_factors(ambient_dim: int, metric_type: str):
504529
from pytential.symbolic.primitives import \
505530
_equilateral_parametrization_derivative_matrix
506531
equi_pder = _equilateral_parametrization_derivative_matrix(ambient_dim)
507532
equi_form1 = sym.cse(equi_pder.T @ equi_pder, "pd_mat_jtj")
508533

509-
from pytential.symbolic.primitives import _small_mat_eigenvalues
510-
return [
511-
sym.cse(sym.sqrt(s), f"mapping_singval_{i}")
512-
for i, s in enumerate(_small_mat_eigenvalues(4 * equi_form1))
513-
]
534+
return metric_from_form1(equi_form1, metric_type)
514535

515536

516-
def make_quad_stretch_factors(ambient_dim):
537+
def make_quad_stretch_factors(ambient_dim: int, metric_type: str):
517538
pder = sym.parametrization_derivative_matrix(ambient_dim, ambient_dim - 1)
518539
form1 = sym.cse(pder.T @ pder, "pd_mat_jtj")
519540

520-
from pytential.symbolic.primitives import _small_mat_eigenvalues
521-
return [
522-
sym.cse(sym.sqrt(s), f"mapping_singval_{i}")
523-
for i, s in enumerate(_small_mat_eigenvalues(4 * form1))
524-
]
541+
return metric_from_form1(form1, metric_type)
525542

526543

527544
@pytest.mark.parametrize("order", [4, 8])
@@ -569,42 +586,22 @@ def test_stretch_factor(actx_factory, order,
569586
print(f"simplex_discr.ndofs: {simplex_discr.ndofs}")
570587
print(f"quad_discr.ndofs: {quad_discr.ndofs}")
571588

572-
if metric_type == "eigs":
573-
sym_simplex = make_simplex_stretch_factors(ambient_dim)
574-
sym_quad = make_quad_stretch_factors(ambient_dim)
575-
elif metric_type == "det":
576-
form1 = sym.first_fundamental_form(ambient_dim)
577-
sym_simplex = np.array([
578-
form1[0, 0] * form1[1, 1] - form1[0, 1] * form1[1, 0],
579-
sym.Ones(),
580-
], dtype=object)
581-
sym_quad = sym_simplex
582-
elif metric_type == "trace":
583-
form1 = sym.first_fundamental_form(ambient_dim)
584-
sym_simplex = np.array([
585-
form1[0, 0] + form1[1, 1],
586-
sym.Ones(),
587-
], dtype=object)
588-
sym_quad = sym_simplex
589-
elif metric_type == "norm":
590-
form1 = sym.first_fundamental_form(ambient_dim)
591-
sym_simplex = np.array([
592-
sym.sqrt(sum(form1 * form1)),
593-
sym.Ones(),
594-
], dtype=object)
595-
sym_quad = sym_simplex
596-
else:
597-
raise ValueError(f"unknown metric type: '{metric_type}'")
589+
sym_simplex = make_simplex_stretch_factors(ambient_dim, metric_type)
590+
sym_quad = make_quad_stretch_factors(ambient_dim, metric_type)
591+
592+
s = bind(simplex_discr, sym_simplex)(actx)
593+
q = bind(quad_discr, sym_quad)(actx)
598594

599-
s0, s1 = bind(simplex_discr, sym_simplex)(actx)
600-
q0, q1 = bind(quad_discr, sym_quad)(actx)
595+
def print_bounds(x, name):
596+
for i, si in enumerate(x):
597+
print("{}{} [{:.12e}, {:.12e}]".format(
598+
name, i,
599+
actx.to_numpy(actx.np.min(si))[()],
600+
actx.to_numpy(actx.np.min(si))[()]
601+
))
601602

602-
print("s0")
603-
print(actx.to_numpy(actx.np.min(s0))[()], actx.to_numpy(actx.np.max(s0))[()])
604-
print(actx.to_numpy(actx.np.min(q0))[()], actx.to_numpy(actx.np.max(q0))[()])
605-
print("s1")
606-
print(actx.to_numpy(actx.np.min(s1))[()], actx.to_numpy(actx.np.max(s1))[()])
607-
print(actx.to_numpy(actx.np.min(q1))[()], actx.to_numpy(actx.np.max(q1))[()])
603+
print_bounds(s, "s")
604+
print_bounds(q, "q")
608605

609606
if not visualize:
610607
return
@@ -615,17 +612,23 @@ def test_stretch_factor(actx_factory, order,
615612

616613
from meshmode.discretization.visualization import make_visualizer
617614
vis = make_visualizer(actx, simplex_discr, order, force_equidistant=True)
618-
vis.write_vtk_file(f"simplex_{suffix}.vtu", [
619-
("s0", s0), ("s1", s1),
620-
], overwrite=True, use_high_order=True)
615+
vis.write_vtk_file(f"simplex_{suffix}.vtu",
616+
[(f"s{i}", si) for i, si in enumerate(s)],
617+
overwrite=True, use_high_order=True)
621618

622619
vis = make_visualizer(actx, quad_discr, order, force_equidistant=True)
623-
vis.write_vtk_file(f"quad_{suffix}.vtu", [
624-
("s0", q0), ("s1", q1),
625-
], overwrite=True, use_high_order=True)
620+
vis.write_vtk_file(f"quad_{suffix}.vtu",
621+
[(f"q{i}", qi) for i, qi in enumerate(q)],
622+
overwrite=True, use_high_order=True)
626623

627624
# }}}
628625

626+
if s.size != 2:
627+
return
628+
629+
s0, s1 = s
630+
q0, q1 = q
631+
629632
# {{{ plot reference simplex
630633

631634
if quad_discr.mesh.nelements <= 2:

0 commit comments

Comments
 (0)