Skip to content

Commit e0e8dbb

Browse files
caseyflex and yaugenst-flex
authored and committed
Add option to use dispersion fitter without rich.progress
1 parent 86ea96e commit e0e8dbb

File tree

4 files changed

+224
-96
lines changed

4 files changed

+224
-96
lines changed

tests/test_plugins/test_dispersion_fitter.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import io
4+
from unittest import mock
5+
36
import matplotlib.pyplot as plt
47
import numpy as np
58
import pydantic.v1 as pydantic
69
import pytest
710
import responses
11+
import rich
12+
from rich.progress import Progress
813

914
import tidy3d as td
1015
from tidy3d.exceptions import SetupError, ValidationError
@@ -288,3 +293,37 @@ def test_dispersion_loss_samples():
288293
ep = nAlGaN_mat.eps_model(freq_list)
289294
for e in ep:
290295
assert e.imag >= 0
296+
297+
298+
def test_dispersion_show_progress():
    """Progress output is emitted with ``show_progress=True`` and suppressed otherwise."""
    eps_real = 2.5
    loss_tangent = 1e-2
    frequency_range = (1e9, 6e9)

    # Capture rich output by forcing every Progress instance onto an in-memory
    # console (force_terminal so the bar is actually rendered, not elided).
    console_out = io.StringIO()
    test_console = rich.console.Console(file=console_out, force_terminal=True)
    original_init = Progress.__init__

    def patched_init(self, *args, **kwargs):
        # Redirect any Progress created inside the fitter to the test console.
        kwargs["console"] = test_console
        original_init(self, *args, **kwargs)

    with mock.patch("rich.progress.Progress.__init__", patched_init):
        FastDispersionFitter.constant_loss_tangent_model(
            eps_real, loss_tangent, frequency_range, show_progress=True
        )
    with_progress = console_out.getvalue()

    # Reset the capture buffer before the silent run.
    console_out.truncate(0)
    console_out.seek(0)

    FastDispersionFitter.constant_loss_tangent_model(
        eps_real, loss_tangent, frequency_range, show_progress=False
    )
    without_progress = console_out.getvalue()

    # A visible progress bar must produce strictly more console output than none.
    assert len(with_progress) > len(without_progress)

tidy3d/components/dispersion_fitter.py

Lines changed: 129 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66

77
import numpy as np
88
from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator
9-
from rich.progress import Progress
109

1110
from tidy3d.constants import fp_eps
1211
from tidy3d.exceptions import ValidationError
13-
from tidy3d.log import get_logging_console, log
12+
from tidy3d.log import Progress, get_logging_console, log
1413

1514
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
1615
from .types import ArrayComplex1D, ArrayComplex2D, ArrayFloat1D, ArrayFloat2D
@@ -822,12 +821,11 @@ def fit(
822821
823822
Returns
824823
-------
825-
Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]
824+
tuple[tuple[float, ArrayComplex1D, ArrayComplex1D], float]
826825
Best fitting result: (dispersive medium parameters, weighted RMS error).
827826
The dispersive medium parameters have the form (resp_inf, poles, residues)
828827
and are in the original unscaled units.
829828
"""
830-
831829
if max_num_poles < min_num_poles:
832830
raise ValidationError(
833831
"Dispersion fitter cannot have 'max_num_poles' less than 'min_num_poles'."
@@ -866,88 +864,80 @@ def make_configs():
866864

867865
configs = make_configs()
868866

869-
with Progress(console=get_logging_console()) as progress:
867+
with Progress(
868+
console=get_logging_console(), show_progress=init_model.show_progress
869+
) as progress:
870870
task = progress.add_task(
871-
f"Fitting to weighted RMS of {tolerance_rms}...",
871+
description=f"Fitting to weighted RMS of {tolerance_rms}...",
872872
total=len(configs),
873873
visible=init_model.show_progress,
874874
)
875875

876-
while not progress.finished:
877-
# try different initial pole configurations
878-
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
879-
model = init_model.updated_copy(
880-
num_poles=num_poles,
881-
relaxed=relaxed,
882-
smooth=smooth,
883-
logspacing=logspacing,
884-
optimize_eps_inf=optimize_eps_inf,
876+
# try different initial pole configurations
877+
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
878+
model = init_model.updated_copy(
879+
num_poles=num_poles,
880+
relaxed=relaxed,
881+
smooth=smooth,
882+
logspacing=logspacing,
883+
optimize_eps_inf=optimize_eps_inf,
884+
)
885+
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)
886+
887+
if model.rms_error < best_model.rms_error:
888+
log.debug(
889+
f"Fitter: possible improved fit with "
890+
f"rms_error={model.rms_error:.3g} found using "
891+
f"relaxed={model.relaxed}, "
892+
f"smooth={model.smooth}, "
893+
f"logspacing={model.logspacing}, "
894+
f"optimize_eps_inf={model.optimize_eps_inf}, "
895+
f"loss_in_bounds={model.loss_in_bounds}, "
896+
f"passivity_optimized={model.passivity_optimized}, "
897+
f"sellmeier_passivity={model.sellmeier_passivity}."
885898
)
886-
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)
887-
888-
if model.rms_error < best_model.rms_error:
889-
log.debug(
890-
f"Fitter: possible improved fit with "
891-
f"rms_error={model.rms_error:.3g} found using "
892-
f"relaxed={model.relaxed}, "
893-
f"smooth={model.smooth}, "
894-
f"logspacing={model.logspacing}, "
895-
f"optimize_eps_inf={model.optimize_eps_inf}, "
896-
f"loss_in_bounds={model.loss_in_bounds}, "
897-
f"passivity_optimized={model.passivity_optimized}, "
898-
f"sellmeier_passivity={model.sellmeier_passivity}."
899-
)
900-
if model.loss_in_bounds and model.sellmeier_passivity:
901-
best_model = model
902-
else:
903-
if (
904-
not warned_about_passivity_num_iters
905-
and model.passivity_num_iters_too_small
906-
):
907-
warned_about_passivity_num_iters = True
908-
log.warning(
909-
"Did not finish enforcing passivity in dispersion fitter. "
910-
"If the fit is not good enough, consider increasing "
911-
"'AdvancedFastFitterParam.passivity_num_iters'."
912-
)
913-
if (
914-
not warned_about_slsqp_constraint_scale
915-
and model.slsqp_constraint_scale_too_small
916-
):
917-
warned_about_slsqp_constraint_scale = True
918-
log.warning(
919-
"SLSQP constraint scale may be too small. "
920-
"If the fit is not good enough, consider increasing "
921-
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
922-
)
899+
if model.loss_in_bounds and model.sellmeier_passivity:
900+
best_model = model
901+
else:
902+
if not warned_about_passivity_num_iters and model.passivity_num_iters_too_small:
903+
warned_about_passivity_num_iters = True
904+
log.warning(
905+
"Did not finish enforcing passivity in dispersion fitter. "
906+
"If the fit is not good enough, consider increasing "
907+
"'AdvancedFastFitterParam.passivity_num_iters'."
908+
)
909+
if (
910+
not warned_about_slsqp_constraint_scale
911+
and model.slsqp_constraint_scale_too_small
912+
):
913+
warned_about_slsqp_constraint_scale = True
914+
log.warning(
915+
"SLSQP constraint scale may be too small. "
916+
"If the fit is not good enough, consider increasing "
917+
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
918+
)
919+
progress.update(
920+
task,
921+
advance=1,
922+
description=f"Best weighted RMS error so far: {best_model.rms_error:.3g}",
923+
refresh=True,
924+
)
925+
926+
# if below tolerance, return
927+
if best_model.rms_error < tolerance_rms:
923928
progress.update(
924929
task,
925-
advance=1,
926-
description=f"Best weighted RMS error so far: {best_model.rms_error:.3g}",
930+
completed=len(configs),
931+
description=f"Best weighted RMS error: {best_model.rms_error:.3g}",
927932
refresh=True,
928933
)
929-
930-
# if below tolerance, return
931-
if best_model.rms_error < tolerance_rms:
932-
progress.update(
933-
task,
934-
completed=len(configs),
935-
description=f"Best weighted RMS error: {best_model.rms_error:.3g}",
936-
refresh=True,
937-
)
938-
log.info(
939-
"Found optimal fit with weighted RMS error %.3g",
940-
best_model.rms_error,
941-
)
942-
if best_model.show_unweighted_rms:
943-
log.info(
944-
"Unweighted RMS error %.3g",
945-
best_model.unweighted_rms_error,
946-
)
947-
return (
948-
best_model.pole_residue,
949-
best_model.rms_error,
950-
)
934+
log.info(f"Found optimal fit with weighted RMS error {best_model.rms_error:.3g}")
935+
if best_model.show_unweighted_rms:
936+
log.info(f"Unweighted RMS error {best_model.unweighted_rms_error:.3g}")
937+
return (
938+
best_model.pole_residue,
939+
best_model.rms_error,
940+
)
951941

952942
# if exited loop, did not reach tolerance (warn)
953943
progress.update(
@@ -958,16 +948,73 @@ def make_configs():
958948
)
959949

960950
log.warning(
961-
"Unable to fit with weighted RMS error under 'tolerance_rms' of %.3g", tolerance_rms
951+
f"Unable to fit with weighted RMS error under 'tolerance_rms' of {tolerance_rms:.3g}"
962952
)
963-
log.info("Returning best fit with weighted RMS error %.3g", best_model.rms_error)
953+
log.info(f"Returning best fit with weighted RMS error {best_model.rms_error:.3g}")
964954
if best_model.show_unweighted_rms:
965-
log.info(
966-
"Unweighted RMS error %.3g",
967-
best_model.unweighted_rms_error,
968-
)
955+
log.info(f"Unweighted RMS error {best_model.unweighted_rms_error:.3g}")
969956

970957
return (
971958
best_model.pole_residue,
972959
best_model.rms_error,
973960
)
961+
962+
963+
def constant_loss_tangent_model(
    eps_real: float,
    loss_tangent: float,
    frequency_range: tuple[float, float],
    max_num_poles: PositiveInt = DEFAULT_MAX_POLES,
    number_sampling_frequency: PositiveInt = 10,
    tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS,
    scale_factor: float = 1,
    show_progress: bool = True,
) -> tuple[tuple[float, ArrayComplex1D, ArrayComplex1D], float]:
    """Fit a constant loss tangent material model.

    Parameters
    ----------
    eps_real : float
        Real part of permittivity.
    loss_tangent : float
        Loss tangent.
    frequency_range : tuple[float, float]
        Frequency range for the material to exhibit constant loss tangent response.
    max_num_poles : PositiveInt, optional
        Maximum number of poles in the model.
    number_sampling_frequency : PositiveInt, optional
        Number of sampling frequencies to compute RMS error for fitting.
    tolerance_rms : float, optional
        Weighted RMS error below which the fit is successful and the result is returned.
    scale_factor : PositiveFloat, optional
        Factor to rescale frequency by before fitting.
    show_progress : bool
        Whether to show a progress bar.

    Returns
    -------
    tuple[tuple[float, ArrayComplex1D, ArrayComplex1D], float]
        Best fitting result: (dispersive medium parameters, weighted RMS error).
        The dispersive medium parameters have the form (resp_inf, poles, residues)
        and are in the original unscaled units.
    """
    # Fewer than 2 samples makes a linspace degenerate; sample the band center only.
    if number_sampling_frequency < 2:
        frequencies = np.array([np.mean(frequency_range)])
    else:
        frequencies = np.linspace(frequency_range[0], frequency_range[1], number_sampling_frequency)

    omega_data = frequencies * 2 * np.pi
    # Constant complex permittivity eps_real * (1 + 1j * loss_tangent) at every sample.
    eps_complex = np.full_like(frequencies, eps_real) * (1 + 1j * loss_tangent)

    advanced_param = AdvancedFastFitterParam(show_progress=show_progress)

    return fit(
        omega_data=omega_data,
        resp_data=eps_complex,
        max_num_poles=max_num_poles,
        tolerance_rms=tolerance_rms,
        scale_factor=scale_factor,
        advanced_param=advanced_param,
    )

tidy3d/log.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import inspect
6+
from contextlib import contextmanager
67
from datetime import datetime
78
from typing import Callable, Optional, Union
89

@@ -444,3 +445,33 @@ def get_logging_console() -> Console:
444445
if "console" not in log.handlers:
445446
set_logging_console()
446447
return log.handlers["console"].console
448+
449+
450+
class NoOpProgress:
    """Silent stand-in for ``rich.progress.Progress``.

    Implements the context-manager protocol plus the two methods the fitter
    calls (``add_task`` / ``update``) as no-ops, so callers can use a single
    code path whether or not a progress bar should be displayed.
    """

    def __enter__(self):
        # Hand back the (inert) progress object itself.
        return self

    def __exit__(self, *exc_info, **kwargs):
        # Nothing to tear down; exceptions are not suppressed (returns None).
        pass

    def add_task(self, *args, **kwargs):
        # rich returns a TaskID here; callers only feed it back into update(),
        # which also ignores it, so returning None is a safe substitute.
        pass

    def update(self, *args, **kwargs):
        pass
464+
465+
466+
@contextmanager
def Progress(console, show_progress):
    """Yield a progress reporter appropriate for ``show_progress``.

    Yields a live ``rich.progress.Progress`` bound to ``console`` when
    ``show_progress`` is true, and a silent ``NoOpProgress`` otherwise.
    """
    if not show_progress:
        with NoOpProgress() as silent:
            yield silent
        return

    # Deferred import so rich is only required when a bar is actually shown.
    from rich.progress import Progress as RichProgress

    with RichProgress(console=console) as live:
        yield live

0 commit comments

Comments
 (0)