-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgen_models.py
More file actions
337 lines (276 loc) · 11.6 KB
/
gen_models.py
File metadata and controls
337 lines (276 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import shutil
import subprocess
import sys
from pathlib import Path
from rich.logging import RichHandler
from rich.traceback import install
# Pretty, colourised tracebacks for uncaught exceptions (locals hidden).
install(show_locals=False)
# Route all log records through Rich. Rich draws its own time/level columns,
# so the format string only needs the bare message.
logging.basicConfig(
    handlers=[RichHandler(markup=True, rich_tracebacks=True)],
    format="%(message)s",
    level=logging.INFO,
)
logging.getLogger("httpx").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
class Generator:
    """Clone the eda-labs/openapi repo and generate Pydantic v2 models from it.

    Specs are cloned into ``./build`` (wiped on construction) and sanitized in
    place. They are then fed to ``datamodel-codegen``, and every resulting
    ``models.py`` is formatted with ``ruff``. App models land under
    ``<output_dir>/apps/<name>/<version>/`` and core models under
    ``<output_dir>/core/<version>/``.
    """

    # Prefix of local JSON-schema references inside an OpenAPI document.
    _SCHEMA_REF_PREFIX = "#/components/schemas/"

    def __init__(
        self,
        output_dir: str,
        version: str,
        verbose: bool,
    ):
        """Prepare an empty build dir and store generation settings.

        :param output_dir: Directory the generated model packages are written to.
        :param version: Branch or tag of the openapi repo to clone.
        :param verbose: When true, enable debug logging for this module.
        """
        # openapi repo holding the EDA specs
        self.repo_url = "https://github.com/eda-labs/openapi"
        # a dir (relative to the CWD) to clone the openapi repo into
        self.build_dir = Path("./build")
        # delete the build dir if it exists so no stale specs survive, then
        # recreate it empty
        if self.build_dir.exists():
            shutil.rmtree(self.build_dir)
        self.build_dir.mkdir(exist_ok=True)
        self.output_dir = Path(output_dir)
        self.verbose = verbose
        self.version = version
        if self.verbose:
            logger.setLevel(logging.DEBUG)

    def clone_repo(self):
        """Clone ``self.version`` of the openapi repo into the build dir.

        :raises subprocess.CalledProcessError: if ``git clone`` fails; git's
            stderr is logged before the exception propagates.
        """
        try:
            subprocess.run(
                ["git", "clone", "-b", self.version, self.repo_url, self.build_dir],
                check=True,
                # capture (rather than discard) so a failed clone explains itself
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as err:
            logger.error(
                f"git clone of {self.repo_url} ({self.version}) failed: "
                f"{(err.stderr or '').strip()}"
            )
            raise

    def process_specs(self):
        """Sanitize every spec under the build dir and generate models for it.

        App specs are read from ``<build>/apps/**/*.json`` and the core spec(s)
        from ``<build>/core/**/*.json``. Exits the process with status 1 if
        either directory is missing.
        """
        apps_dir = self.build_dir.joinpath("apps")
        if not apps_dir.exists():
            logger.error("apps dir not found in the cloned specs repo")
            sys.exit(1)

        for spec_file in apps_dir.glob("**/*.json"):
            # uncomment and change to a desired name if you want to generate
            # just one model
            # if spec_file.name != "services.json":
            #     continue
            logger.info(f"Processing {spec_file}")
            api_name, api_version = extract_name_version(spec_file)
            logger.debug(f"API name: {api_name}, API version: {api_version}")
            self.sanitize_schema_objects(spec_file, api_name, api_version)
            self.generate_classes_for_spec(spec_file, api_name, api_version)

        # process the core spec that is a single file in its own dir
        core_dir = self.build_dir.joinpath("core")
        if not core_dir.exists():
            logger.error("core dir not found in the cloned specs repo")
            sys.exit(1)

        for spec_file in core_dir.glob("**/*.json"):
            api_name = "core"
            # core api has a v0.0.1 in the spec but that will change;
            # for now use the version provided by the user on the cmd line,
            # made identifier-friendly
            api_version = self.version.replace(".", "_").replace("-", "_")
            logger.debug(f"API name: {api_name}, API version: {api_version}")
            self.sanitize_schema_objects(spec_file, api_name, api_version)
            self.generate_classes_for_spec(spec_file, api_name, api_version)

    @staticmethod
    def _venv_env() -> dict[str, str]:
        """Return a copy of the environment with the venv's bin dir first on PATH.

        The venv root is ``$VIRTUAL_ENV`` if set, otherwise ``./.venv``.
        """
        env = os.environ.copy()
        venv_root = Path(os.environ.get("VIRTUAL_ENV", ".venv"))
        # venvs keep executables in Scripts\ on Windows and bin/ elsewhere
        venv_bin = venv_root / ("Scripts" if os.name == "nt" else "bin")
        current_path = env.get("PATH", "")
        # use the platform separator and skip an empty PATH: a dangling
        # separator adds an empty entry, which POSIX treats as the CWD
        env["PATH"] = os.pathsep.join(
            part for part in (str(venv_bin), current_path) if part
        )
        return env

    def generate_classes_for_spec(
        self, spec_file: Path, api_name: str, api_version: str
    ):
        """Generate Pydantic classes for the given sanitized spec file.

        Failures of ``datamodel-codegen`` or ``ruff`` are logged and swallowed so
        the remaining specs are still processed.

        :param spec_file: Path to the spec file (inside the build dir)
        :param api_name: Name of the API
        :param api_version: Version of the API
        """
        # core models go right under the output dir, while all the apps go
        # under <output>/apps/. Decide by the spec's top-level dir inside the
        # build dir instead of a fixed path index.
        top_dir = spec_file.relative_to(self.build_dir).parts[0]
        app_parent_dir = "" if top_dir == "core" else "apps"

        dest_file = self.output_dir.joinpath(
            app_parent_dir, api_name, api_version, "models.py"
        )
        # Create all parent directories of the dest file
        dest_file.parent.mkdir(parents=True, exist_ok=True)

        cmd = [
            "datamodel-codegen",
            "--input",
            spec_file,
            "--input-file-type",
            "openapi",
            "--openapi-scopes",
            "schemas",
            "--output-model-type",
            "pydantic_v2.BaseModel",
            # formatting is done separately with ruff from the venv
            # "--formatters",
            # "ruff-format",
            "--use-annotated",
            "--parent-scoped-naming",
            "--collapse-root-models",
            "--disable-timestamp",
            "--reuse-model",
            # can't use model order, since Topologies are defined before Topology
            # maybe worth fixing the order in the model
            # "--keep-model-order",
            "--use-schema-description",
            "--enum-field-as-literal",
            "all",
            "--output",
            dest_file,
        ]

        # run the tools from the virtual env first, if present
        env = self._venv_env()
        try:
            logger.info(f"Generating models for {spec_file}...")
            subprocess.run(cmd, check=True, env=env)

            # Format the generated file with ruff
            logger.debug(f"Formatting {dest_file} with ruff...")
            subprocess.run(["ruff", "format", str(dest_file)], check=True, env=env)
        except subprocess.CalledProcessError as e:
            # best effort: log and move on to the next spec
            logger.error(f"Error generating models for {spec_file}: {e}")

    @staticmethod
    def _simplify_schema_names(
        schemas: dict,
    ) -> tuple[dict, dict[str, str], list[str]]:
        """Strip dotted ``com.nokia.eda...`` prefixes from schema names.

        ``com.nokia.eda.services.v1alpha1.BridgeDomainList`` becomes
        ``BridgeDomainList``. A name is left untouched when its simplified
        form is already used by another schema, so nothing is overwritten.

        :param schemas: Mapping of schema name -> schema definition.
        :returns: ``(new_schemas, renames, collisions)`` where ``new_schemas``
            preserves the input order, ``renames`` maps each old name to its
            new name, and ``collisions`` lists names kept because of a clash.
        """
        new_schemas = {}
        renames = {}
        collisions = []
        for schema_name, schema_def in schemas.items():
            target = schema_name
            if "com.nokia.eda" in schema_name:
                simple = schema_name.split(".")[-1]
                # clash with a plain schema (anywhere) or an earlier rename
                if simple in schemas or simple in new_schemas:
                    collisions.append(schema_name)
                else:
                    target = simple
                    renames[schema_name] = simple
            new_schemas[target] = schema_def
        return new_schemas, renames, collisions

    def sanitize_schema_objects(self, spec_file: Path, api_name: str, api_version: str):
        """Simplify schema names in a spec file and rewrite the file in place.

        Schema names containing ``com.nokia.eda`` are reduced to their last
        dotted segment, because datamodel-codegen produces bad output for
        dotted names. ``$ref`` values pointing at renamed schemas are rewritten
        to match. The ``paths`` section is dropped because model generation
        does not need it.

        :param spec_file: Path to the spec file; it is overwritten with the result.
        :param api_name: Name of the API (used for the fallback ref rewrite).
        :param api_version: Version of the API (used for the fallback ref rewrite).
        """
        with open(spec_file, "r", encoding="utf-8") as f:
            spec_data = json.load(f)

        if "components" not in spec_data or "schemas" not in spec_data["components"]:
            logger.info(f"No schemas found in {spec_file}")
            return

        schemas = spec_data["components"]["schemas"]
        for name in schemas:
            logger.debug(f"Schema name: {name}")

        new_schemas, renames, collisions = self._simplify_schema_names(schemas)
        for old_name, new_name in renames.items():
            logger.debug(f"Renaming schema: {old_name} -> {new_name}")
        for name in collisions:
            # renaming would silently overwrite another schema; keep the full name
            logger.warning(
                f"Not renaming schema {name} in {spec_file}: "
                f"{name.split('.')[-1]} is already used by another schema"
            )
        if renames:
            spec_data["components"]["schemas"] = new_schemas

        # Remove the paths section entirely as it's not needed for model
        # generation and it contains old ref links with dots in schema names
        if "paths" in spec_data:
            logger.debug(f"Removing paths section from {spec_file}")
            del spec_data["paths"]

        # rewrite refs through the same rename map so none are left dangling
        self._update_refs(spec_data, api_name, api_version, renames)

        with open(spec_file, "w", encoding="utf-8") as f:
            json.dump(spec_data, f, indent=2)
        logger.info(f"Wrote schema to {spec_file}")

    def generate_models(self):
        """Clone the specs repo, then sanitize every spec and generate models."""
        # parents=True so a nested --output path works
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.clone_repo()
        self.process_specs()

    def _rewrite_ref(
        self, ref: str, api_name: str, api_version: str, renames=None
    ) -> str:
        """Return *ref* rewritten to point at the simplified schema name.

        With *renames*, local refs are mapped through it; anything else falls
        back to stripping the ``com.nokia.eda.<api_name>.<api_version>.`` prefix.
        """
        prefix = self._SCHEMA_REF_PREFIX
        if renames is not None and ref.startswith(prefix):
            name = ref[len(prefix):]
            return prefix + renames.get(name, name)
        legacy = f"com.nokia.eda.{api_name}.{api_version}"
        if f"{prefix}{legacy}" in ref:
            return ref.replace(f"{legacy}.", "")
        return ref

    def _update_refs(self, obj, api_name: str, api_version: str, renames=None):
        """Recursively rewrite ``$ref`` values in *obj* (in place).

        :param obj: Dict/list tree loaded from JSON; mutated in place.
        :param api_name: API name used by the prefix-stripping fallback.
        :param api_version: API version used by the prefix-stripping fallback.
        :param renames: Optional ``{old_name: new_name}`` map from schema
            renaming. When given, local ``#/components/schemas/...`` refs are
            rewritten through it so they always match the renamed schemas.
            When omitted, only the prefix-stripping rule is applied.
        """
        if isinstance(obj, dict):
            for key, value in list(obj.items()):
                if key == "$ref" and isinstance(value, str):
                    obj[key] = self._rewrite_ref(value, api_name, api_version, renames)
                elif isinstance(value, (dict, list)):
                    self._update_refs(value, api_name, api_version, renames)
        elif isinstance(obj, list):
            for item in obj:
                if isinstance(item, (dict, list)):
                    self._update_refs(item, api_name, api_version, renames)
def extract_name_version(file: Path) -> tuple[str, str]:
    """Return ``(api_name, api_version)`` derived from a spec file path.

    The API name is the file name without its extension, and the API version
    is the name of the directory that directly contains the file. For example,
    ``build/apps/bootstrap.eda.nokia.com/v1alpha1/bootstrap.json`` yields
    ``("bootstrap", "v1alpha1")``.
    """
    api_name = file.stem
    # parts[-1] is the file itself; parts[-2] is its parent directory
    api_version = file.parts[-2]
    return (api_name, api_version)
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the whole pipeline
    # (clone -> sanitize -> generate -> format).
    parser = argparse.ArgumentParser(
        description="Discover OpenAPI specifications and generate Pydantic models."
    )
    parser.add_argument(
        "--output",
        help="Path to the output directory.",
        default="./pydantic_eda",
        type=str,
    )
    parser.add_argument(
        "--version",
        help="openapi repo version (tag) to get the models from. Default: main",
        default="main",
        type=str,
    )
    parser.add_argument(
        "--verbose",
        help="Enable verbose logging. Default: False",
        action="store_true",
    )
    args = parser.parse_args()

    generator = Generator(
        verbose=args.verbose,
        version=args.version,
        output_dir=args.output,
    )
    generator.generate_models()