Skip to content

Commit 96d64e3

Browse files
momchil-flexyaugenst-flex
authored andcommitted
Turning a few validators into root validators
1 parent f9680d1 commit 96d64e3

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

tidy3d/components/base_sim/data/sim_data.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ def monitor_data(self) -> dict[str, AbstractMonitorData]:
5252
"""Dictionary mapping monitor name to its associated :class:`AbstractMonitorData`."""
5353
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}
5454

55-
@pd.validator("data", always=True)
56-
@skip_if_fields_missing(["simulation"])
57-
def data_monitors_match_sim(cls, val, values):
55+
@pd.root_validator(skip_on_failure=True)
56+
def data_monitors_match_sim(cls, values):
5857
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
5958
``.simulation``.
6059
"""
6160
sim = values.get("simulation")
61+
data = values.get("data")
6262

63-
for mnt_data in val:
63+
for mnt_data in data:
6464
try:
6565
monitor_name = mnt_data.monitor.name
6666
sim.get_monitor_by_name(monitor_name)
@@ -69,7 +69,7 @@ def data_monitors_match_sim(cls, val, values):
6969
f"Data with monitor name '{monitor_name}' supplied "
7070
f"but not found in the original '{sim.type}'."
7171
) from exc
72-
return val
72+
return values
7373

7474
@pd.validator("data", always=True)
7575
@skip_if_fields_missing(["simulation"])

tidy3d/components/data/unstructured/base.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,21 @@ def values_right_indexing(cls, val):
114114
)
115115
return val
116116

117-
@pd.validator("values", always=True)
118-
@skip_if_fields_missing(["points"])
119-
def number_of_values_matches_points(cls, val, values):
117+
@pd.root_validator(skip_on_failure=True)
118+
def number_of_values_matches_points(cls, values):
120119
"""Check that the number of data values matches the number of grid points."""
121-
num_values = len(val.index)
122-
123120
points = values.get("points")
124-
num_points = len(points)
125-
if num_points != num_values:
126-
raise ValidationError(
127-
f"The number of data values ({num_values}) does not match the number of grid "
128-
f"points ({num_points})."
129-
)
130-
return val
121+
vals = values.get("values")
122+
123+
if points is not None and vals is not None:
124+
num_points = len(points)
125+
num_values = len(vals.index)
126+
if num_points != num_values:
127+
raise ValidationError(
128+
f"The number of data values ({num_values}) does not match the number of grid "
129+
f"points ({num_points})."
130+
)
131+
return values
131132

132133
@pd.validator("cells", always=True)
133134
def match_cells_to_vtk_type(cls, val):

tidy3d/components/tcad/data/sim_data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pydantic.v1 as pd
1010

11-
from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing
11+
from tidy3d.components.base import Tidy3dBaseModel
1212
from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData
1313
from tidy3d.components.data.data_array import (
1414
SpatialDataArray,
@@ -560,23 +560,23 @@ def mesher(self) -> VolumeMesher:
560560
monitors=self.monitors,
561561
)
562562

563-
@pd.validator("data", always=True)
564-
@skip_if_fields_missing(["monitors"])
565-
def data_monitors_match_sim(cls, val, values):
563+
@pd.root_validator(skip_on_failure=True)
564+
def data_monitors_match_sim(cls, values):
566565
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
567566
``.simulation``.
568567
"""
569568
monitors = values.get("monitors")
569+
data = values.get("data")
570570
mnt_names = {mnt.name for mnt in monitors}
571571

572-
for mnt_data in val:
572+
for mnt_data in data:
573573
monitor_name = mnt_data.monitor.name
574574
if monitor_name not in mnt_names:
575575
raise DataError(
576576
f"Data with monitor name '{monitor_name}' supplied "
577577
f"but not found in the list of monitors."
578578
)
579-
return val
579+
return values
580580

581581
def get_monitor_by_name(self, name: str) -> VolumeMeshMonitor:
582582
"""Return monitor named 'name'."""

0 commit comments

Comments
 (0)