1+ """Automatically generate schemas from existing data using pandas."""
2+
3+ from enum import StrEnum
4+ from typing import Any , Type , Annotated
5+
6+ from emmet .core .types .typing import NullableDateTimeType , DateTimeType
7+ import pandas as pd
8+ from pathlib import Path
9+ from pydantic import BaseModel , Field , model_validator , create_model , BeforeValidator
10+
11+ _complex_type_validator = BeforeValidator (lambda x : (x .real ,x .imag ) if isinstance (x ,complex ) else x )
12+
13+ ComplexType = Annotated [
14+ tuple [float ,float ],
15+ _complex_type_validator
16+ ]
17+
18+ NullableComplexType = Annotated [
19+ tuple [float ,float ] | None ,
20+ _complex_type_validator
21+ ]
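# Illustrative note (assumption, not part of the original module): a pydantic
# field annotated with ComplexType accepts a raw Python complex and coerces it
# to a (real, imag) tuple before validation, e.g.
#
#     class _Example(BaseModel):
#         z: ComplexType
#
#     assert _Example(z=1 + 2j).z == (1.0, 2.0)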


class FileFormat(StrEnum):
    """Define known file formats for schema autogeneration."""

    CSV = "csv"
    JSON = "json"
    JSONL = "jsonl"


class SchemaGenerator(BaseModel):
31+ """Automatically infer a dataset schema and create a pydantic model from it."""
32+
33+ file_name : str | Path = Field (
34+ description = "The path to the dataset."
35+ )
36+
37+ fmt : FileFormat | None = Field (
38+ None , description = "The dataset file format. If no format is provided, it will be inferred."
39+ )
40+
    @model_validator(mode="before")
    @classmethod
    def check_format(cls, config: dict[str, Any]) -> dict[str, Any]:
        """Resolve the dataset path and infer the file format if not provided."""
        if isinstance(fp := config["file_name"], str):
            config["file_name"] = Path(fp).resolve()

        if config.get("fmt"):
            if isinstance(config["fmt"], str):
                if config["fmt"] in FileFormat.__members__:
                    config["fmt"] = FileFormat[config["fmt"]]
                else:
                    try:
                        config["fmt"] = FileFormat(config["fmt"])
                    except ValueError:
                        raise ValueError(
                            f"Could not interpret submitted file format {config['fmt']}"
                        )
        else:
            try:
                # Check longer format names first so that, e.g., ".jsonl"
                # files are not matched as "json".
                config["fmt"] = next(
                    file_fmt
                    for file_fmt in sorted(
                        FileFormat, key=lambda f: len(f.value), reverse=True
                    )
                    if file_fmt.value in config["file_name"].name
                )
            except StopIteration:
                raise ValueError(
                    f"Could not infer file format for {config['file_name']}"
                )
        return config
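    # Illustrative note (assumption, not from the original code): with the
    # format-inference branch above, SchemaGenerator(file_name="data.jsonl")
    # resolves `fmt` to FileFormat.JSONL without an explicit `fmt` argument.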

    @staticmethod
    def _cast_dtype(dtype, assume_nullable: bool = True):
        """Cast input dtype to parquet-friendly dtypes.

        Accounts for difficulties de-serializing datetimes
        and complex numbers.

        Assumes all fields are nullable by default.
        """
        vname = getattr(dtype, "name", str(dtype)).lower()

        if "datetime" in vname:
            return NullableDateTimeType if assume_nullable else DateTimeType
        elif "complex" in vname:
            return NullableComplexType if assume_nullable else ComplexType

        inferred_type = str
        if "float" in vname:
            inferred_type = float
        elif "int" in vname:
            inferred_type = int

        return (inferred_type | None) if assume_nullable else inferred_type
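    # Illustrative mapping for _cast_dtype (assumption, not exhaustive): with
    # the default assume_nullable=True, "float64" -> float | None,
    # "int64" -> int | None, "datetime64[ns]" -> NullableDateTimeType,
    # "complex128" -> NullableComplexType, and anything else -> str | None.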

    @property
    def pydantic_schema(self) -> Type[BaseModel]:
        """Create the pydantic schema of the data structure."""
        if self.fmt == FileFormat.CSV:
            data = pd.read_csv(self.file_name)

        elif self.fmt in (FileFormat.JSON, FileFormat.JSONL):
            # We exclude the "table" case for `orient` since the user
            # presumably already knows what the schema is.
            for orient in ("columns", "index", "records", "split", "values"):
                try:
                    data = pd.read_json(
                        self.file_name,
                        orient=orient,
                        lines=(self.fmt == FileFormat.JSONL),
                    )
                    break
                except Exception:
                    continue
            else:
                raise ValueError(
                    f"Could not load {self.fmt.value} data, please check manually."
                )

        model_fields = {
            col_name: (
                self._cast_dtype(data.dtypes[col_name]),
                Field(default=None),
            )
            for col_name in data.columns
        }

        return create_model(
            self.file_name.name.split(".", 1)[0],
            **model_fields,
        )
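

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; "example_data.csv" is a
    # placeholder path, not a file that ships with this module). The format is
    # inferred from the file name and the generated model mirrors the columns.
    generator = SchemaGenerator(file_name="example_data.csv")
    DatasetModel = generator.pydantic_schema
    print(DatasetModel.model_fields)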