Skip to content

Commit 3e82ac9

Browse files
add example + barebones schema autogenerator
1 parent 757a136 commit 3e82ac9

File tree

3 files changed

+153
-21
lines changed

3 files changed

+153
-21
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Automatically generate schemas from existing data using pandas."""
2+
3+
from enum import StrEnum
4+
from typing import Any, Type, Annotated
5+
6+
from emmet.core.types.typing import NullableDateTimeType, DateTimeType
7+
import pandas as pd
8+
from pathlib import Path
9+
from pydantic import BaseModel, Field, model_validator, create_model, BeforeValidator
10+
11+
def _split_complex(value):
    """Turn a ``complex`` into a ``(real, imag)`` pair; pass anything else through."""
    if isinstance(value, complex):
        return (value.real, value.imag)
    return value


# Runs before pydantic's own validation so complex inputs arrive as 2-tuples.
_complex_type_validator = BeforeValidator(_split_complex)

# A complex number stored as a (real, imaginary) pair of floats.
ComplexType = Annotated[tuple[float, float], _complex_type_validator]

# Same as ``ComplexType`` but also accepts ``None``.
NullableComplexType = Annotated[tuple[float, float] | None, _complex_type_validator]
22+
23+
class FileFormat(StrEnum):
24+
"""Define known file formats for autogeneration of schemae."""
25+
26+
CSV = "csv"
27+
JSON = "json"
28+
JSONL = "jsonl"
29+
30+
class SchemaGenerator(BaseModel):
    """Automatically infer a dataset schema and create a pydantic model from it.

    The dataset at ``file_name`` is loaded with pandas and each column's dtype
    is mapped to a Python type. Every generated field is declared with a
    default of ``None``.
    """

    file_name: str | Path = Field(description="The path to the dataset.")

    fmt: FileFormat | None = Field(
        None,
        description="The dataset file format. If no format is provided, it will be inferred.",
    )

    @model_validator(mode="before")
    @classmethod
    def check_format(cls, config: dict[str, Any]) -> dict[str, Any]:
        """Normalize ``file_name`` and resolve ``fmt`` before field validation.

        A string ``file_name`` is converted to a resolved ``Path``. A string
        ``fmt`` is accepted either as a ``FileFormat`` member name (``"CSV"``)
        or value (``"csv"``). When no ``fmt`` is given it is inferred from the
        file name.

        Raises:
            ValueError: if a string ``fmt`` is not a known format, or if no
                format can be inferred from the file name.
        """
        if not isinstance(config, dict):
            # Not keyword-style input; leave it for pydantic to handle.
            return config

        # Work on a shallow copy so the caller's dict is not mutated.
        config = dict(config)

        file_name = config.get("file_name")
        if isinstance(file_name, str):
            file_name = config["file_name"] = Path(file_name).resolve()

        if fmt := config.get("fmt"):
            if isinstance(fmt, str):
                if fmt in FileFormat.__members__:
                    config["fmt"] = FileFormat[fmt]
                else:
                    try:
                        config["fmt"] = FileFormat(fmt)
                    except ValueError as exc:
                        raise ValueError(
                            f"Could not interpret submitted file format {fmt}"
                        ) from exc
            return config

        if not isinstance(file_name, Path):
            # Missing or mistyped ``file_name``: let field validation report it
            # instead of failing here with a KeyError / AttributeError.
            return config

        config["fmt"] = cls._infer_format(file_name)
        return config

    @staticmethod
    def _infer_format(path: Path) -> FileFormat:
        """Infer the file format from ``path``'s name.

        Exact (case-insensitive) suffix matches are preferred, scanning from
        the last suffix backwards so ``data.jsonl`` and ``data.jsonl.gz`` both
        resolve to JSONL rather than JSON. If no suffix matches, fall back to
        a substring search of the file name, preferring the longest format
        name so ``jsonl`` still wins over ``json``.

        Raises:
            ValueError: if no known format name occurs in the file name.
        """
        by_value = {file_fmt.value: file_fmt for file_fmt in FileFormat}
        for suffix in reversed(path.suffixes):
            file_fmt = by_value.get(suffix.lstrip(".").lower())
            if file_fmt is not None:
                return file_fmt

        candidates = [file_fmt for file_fmt in FileFormat if file_fmt.value in path.name]
        if not candidates:
            raise ValueError(f"Could not infer file format for {path}")
        return max(candidates, key=lambda file_fmt: len(file_fmt.value))

    @staticmethod
    def _cast_dtype(dtype, assume_nullable: bool = True):
        """Cast input dtype to parquet-friendly dtypes.

        Accounts for difficulties de-serializing datetimes
        and complex numbers.

        The mapping is driven by the lower-cased dtype name: names containing
        ``datetime`` or ``complex`` map to the dedicated annotated types,
        ``float`` / ``int`` / ``bool`` map to the matching builtin, and
        anything else falls back to ``str``.

        Assumes all fields are nullable by default.
        """
        vname = getattr(dtype, "name", str(dtype)).lower()

        if "datetime" in vname:
            return NullableDateTimeType if assume_nullable else DateTimeType
        if "complex" in vname:
            return NullableComplexType if assume_nullable else ComplexType

        if "float" in vname:
            inferred_type = float
        elif "int" in vname:
            inferred_type = int
        elif "bool" in vname:
            # Without this, boolean columns fell through to ``str``, which
            # does not accept the booleans it was inferred from.
            inferred_type = bool
        else:
            inferred_type = str

        return (inferred_type | None) if assume_nullable else inferred_type

    def _load_data(self) -> pd.DataFrame:
        """Read the dataset at ``file_name`` into a DataFrame according to ``fmt``.

        Raises:
            ValueError: if ``fmt`` is not a supported format, or if JSON data
                cannot be read with any of the attempted orientations.
        """
        if self.fmt == FileFormat.CSV:
            return pd.read_csv(self.file_name)

        if self.fmt in (FileFormat.JSON, FileFormat.JSONL):
            is_jsonl = self.fmt == FileFormat.JSONL
            last_error = None
            # we exclude the "table" case for `orient` since the user
            # presumably already knows what the schema is.
            for orient in ("columns", "index", "records", "split", "values"):
                try:
                    return pd.read_json(self.file_name, orient=orient, lines=is_jsonl)
                except Exception as exc:
                    # Deliberately broad: this is a best-effort probe of each
                    # orientation. Keep the last failure for error chaining.
                    last_error = exc
            raise ValueError(
                f"Could not load {self.fmt.value} data, please check manually."
            ) from last_error

        raise ValueError(f"Unsupported file format {self.fmt}")

    @property
    def pydantic_schema(self) -> Type[BaseModel]:
        """Create the pydantic schema of the data structure.

        The model is named after the file name up to its first ``.`` and has
        one field per column, typed via ``_cast_dtype`` and defaulting to
        ``None``. The dataset is re-read on every access.
        """
        data = self._load_data()

        model_fields = {
            # Column labels may be non-strings (e.g. integers), but keyword
            # names passed to ``create_model`` must be strings.
            str(col_name): (self._cast_dtype(dtype), Field(default=None))
            for col_name, dtype in data.dtypes.items()
        }

        # Built outside an f-string: reusing the enclosing quote character
        # inside an f-string is only valid from Python 3.12 onwards.
        model_name = Path(self.file_name).name.split(".", 1)[0]

        return create_model(model_name, **model_fields)

mpcontribs-lux/mpcontribs/lux/autogen/__init__.py

Whitespace-only changes.

mpcontribs-lux/mpcontribs/lux/projects/examples.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,54 @@
1212

1313
from pymatgen.core import Structure
1414

15+
1516
class ExampleSchema(BaseModel):
    """Example schema whose fields carry descriptive metadata."""

    formula: str | None = Field(
        None, description="The chemical formula of the unit cell."
    )

    a0: float | None = Field(
        None,
        description="The experimental equilibrium cubic "
        "lattice constant a, in Å, including zero-point corrections "
        "for nuclear vibration.",
    )

    b0: float | None = Field(
        None,
        description="The experimental bulk modulus at "
        "optimal lattice geometry, in GPa, including zero-point "
        "corrections for nuclear vibration.",
    )

    e0: float | None = Field(
        None,
        description="The experimental cohesive energy, in eV/atom, "
        "including zero-point corrections for nuclear vibration.",
    )

    cif: str | None = Field(
        None,
        description="The structure represented as a Crystallographic Information File.",
    )

    material_id: str | None = Field(
        None,
        description="The Materials Project ID of the structure which "
        "corresponds to this entry. The ID will start with `mp-`",
    )

    @cached_property
    def get_pymatgen_structure(self) -> Structure | None:
        """Parse ``cif`` into a pymatgen ``Structure``, if a CIF is present.

        Returns ``None`` when ``cif`` is unset or empty.

        This is an example of giving downstream users a way to interact
        with your data; more advanced analysis tools can be offered in the
        same manner.
        """
        if not self.cif:
            return None
        return Structure.from_str(self.cif, fmt="cif")

0 commit comments

Comments
 (0)