Skip to content

Commit 7303c49

Browse files
authored
Merge pull request #138 from gdsfactory/fix_models
fix models and add tests for models
2 parents 33d703b + 06a1280 commit 7303c49

26 files changed

+63
-18
lines changed

cspdk/si220/cband/models.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
straight_strip = partial(
2424
sm.straight,
2525
length=10.0,
26-
loss=0.0,
26+
loss_dB_cm=3.0,
2727
wl0=1.55,
2828
neff=2.38,
2929
ng=4.30,
@@ -32,7 +32,7 @@
3232
straight_rib = partial(
3333
sm.straight,
3434
length=10.0,
35-
loss=0.0,
35+
loss_dB_cm=3.0,
3636
wl0=1.55,
3737
neff=2.38,
3838
ng=4.30,
@@ -43,7 +43,7 @@ def straight(
4343
*,
4444
wl: Float = 1.55,
4545
length: float = 10.0,
46-
loss: float = 0.0,
46+
loss_dB_cm: float = 3.0,
4747
cross_section: str = "strip",
4848
) -> sax.SDict:
4949
"""Straight waveguide model."""
@@ -56,7 +56,7 @@ def straight(
5656
return f(
5757
wl=wl, # type: ignore
5858
length=length,
59-
loss=loss,
59+
loss_dB_cm=loss_dB_cm,
6060
)
6161

6262

@@ -76,15 +76,15 @@ def bend_s(
7676
*,
7777
wl: Float = 1.55,
7878
length: float = 10.0,
79-
loss: float = 0.03,
79+
loss_dB_cm=3.0,
8080
cross_section="strip",
8181
) -> sax.SDict:
8282
"""Bend S model."""
8383
# NOTE: it is assumed that `bend_s` exposes it's length in its info dictionary!
8484
return straight(
8585
wl=wl,
8686
length=length,
87-
loss=loss,
87+
loss_dB_cm=loss_dB_cm,
8888
cross_section=cross_section,
8989
)
9090

@@ -93,20 +93,20 @@ def bend_euler(
9393
*,
9494
wl: Float = 1.55,
9595
length: float = 10.0,
96-
loss: float = 0.03,
96+
loss_dB_cm: float = 3,
9797
cross_section="strip",
9898
) -> sax.SDict:
9999
"""Euler bend model."""
100100
# NOTE: it is assumed that `bend_euler` exposes it's length in its info dictionary!
101101
return straight(
102102
wl=wl,
103103
length=length,
104-
loss=loss,
104+
loss_dB_cm=loss_dB_cm,
105105
cross_section=cross_section,
106106
)
107107

108108

109-
bend_euler = partial(bend_euler, cross_section="strip")
109+
bend_euler_strip = partial(bend_euler, cross_section="strip")
110110
bend_euler_rib = partial(bend_euler, cross_section="rib")
111111

112112

@@ -119,7 +119,7 @@ def taper(
119119
*,
120120
wl: Float = 1.55,
121121
length: float = 10.0,
122-
loss: float = 0.0,
122+
loss_dB_cm: float = 0.0,
123123
cross_section="strip",
124124
) -> sax.SDict:
125125
"""Taper model."""
@@ -128,20 +128,19 @@ def taper(
128128
return straight(
129129
wl=wl,
130130
length=length,
131-
loss=loss,
131+
loss_dB_cm=loss_dB_cm,
132132
cross_section=cross_section,
133133
)
134134

135135

136136
taper_rib = partial(taper, cross_section="rib", length=10.0)
137-
taper_ro = partial(taper, cross_section="xs_ro", length=10.0)
138137

139138

140139
def taper_strip_to_ridge(
141140
*,
142141
wl: Float = 1.55,
143142
length: float = 10.0,
144-
loss: float = 0.0,
143+
loss_dB_cm: float = 0.0,
145144
cross_section="strip",
146145
) -> sax.SDict:
147146
"""Taper strip to ridge model."""
@@ -150,7 +149,7 @@ def taper_strip_to_ridge(
150149
return straight(
151150
wl=wl,
152151
length=length,
153-
loss=loss,
152+
loss_dB_cm=loss_dB_cm,
154153
cross_section=cross_section,
155154
)
156155

@@ -239,10 +238,10 @@ def coupler(
239238
# grating couplers Rectangular
240239
##############################
241240

242-
grating_coupler_rectangular = partial(
241+
grating_coupler_rectangular_strip = partial(
243242
sm.grating_coupler, loss=6, bandwidth=35 * nm, wl=1.55
244243
)
245-
grating_coupler_rectangular_rib = grating_coupler_rectangular
244+
grating_coupler_rectangular_rib = grating_coupler_rectangular_strip
246245

247246

248247
def grating_coupler_rectangular(
@@ -253,7 +252,7 @@ def grating_coupler_rectangular(
253252
# TODO: take more grating_coupler_rectangular arguments into account
254253
wl = jnp.asarray(wl) # type: ignore
255254
fs = {
256-
"strip": grating_coupler_rectangular,
255+
"strip": grating_coupler_rectangular_strip,
257256
"rib": grating_coupler_rectangular_rib,
258257
}
259258
f = fs[cross_section]
@@ -264,7 +263,7 @@ def grating_coupler_rectangular(
264263
# grating couplers Elliptical
265264
##############################
266265

267-
grating_coupler_elliptical = partial(
266+
grating_coupler_elliptical_strip = partial(
268267
sm.grating_coupler, loss=6, bandwidth=35 * nm, wl=1.55
269268
)
270269

tests/test_si220_cband.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
import gdsfactory as gf
88
import jsondiff
9+
import numpy as np
910
import pytest
1011
from gdsfactory.difftest import difftest
1112
from pytest_regressions.data_regression import DataRegressionFixture
13+
from pytest_regressions.ndarrays_regression import NDArraysRegressionFixture
1214

1315
from cspdk.si220.cband import PDK
1416

@@ -130,6 +132,50 @@ def test_netlists(
130132
gf.kcl.dkcells[ci].delete()
131133

132134

135+
skip_test_models = {}
136+
137+
138+
models = PDK.models
139+
model_names = sorted(
140+
[
141+
name
142+
for name in set(models.keys()) - set(skip_test_models)
143+
if not name.startswith("_")
144+
]
145+
)
146+
147+
148+
@pytest.mark.parametrize("model_name", model_names)
149+
def test_models_with_wavelength_sweep(
150+
model_name: str, ndarrays_regression: NDArraysRegressionFixture
151+
) -> None:
152+
"""Test models with different wavelengths to avoid regressions in frequency response."""
153+
# Test at different wavelengths
154+
wl = [1.53, 1.55, 1.57]
155+
try:
156+
model = models[model_name]
157+
s_params = model(wl=wl)
158+
except TypeError:
159+
pytest.skip(f"{model_name} does not accept a wl argument")
160+
161+
# Convert s_params dictionary to arrays for regression testing
162+
# s_params is a dict with tuple keys (port pairs) and JAX array values
163+
arrays_to_check = {}
164+
for key, value in sorted(s_params.items()):
165+
# Convert tuple key to string for regression test compatibility
166+
key_str = f"s_{key[0]}_{key[1]}"
167+
# Convert JAX arrays to numpy and separate real/imag parts
168+
169+
value_np = np.array(value)
170+
arrays_to_check[f"{key_str}_real"] = np.real(value_np)
171+
arrays_to_check[f"{key_str}_imag"] = np.imag(value_np)
172+
173+
ndarrays_regression.check(
174+
arrays_to_check,
175+
default_tolerance={"atol": 1e-8, "rtol": 1e-8},
176+
)
177+
178+
133179
if __name__ == "__main__":
134180
component_type = "coupler_symmetric"
135181
c = cells[component_type]()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)