Skip to content

Commit 289ff17

Browse files
authored
Cleanup formatter module (#1026)
And fixes a bug with keywords for examples.
1 parent 236a3a1 commit 289ff17

File tree

7 files changed

+274
-346
lines changed

7 files changed

+274
-346
lines changed

.generator/conftest.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,8 @@
1515
from pytest_bdd import given, parsers, then, when
1616

1717
from generator import openapi
18-
from generator.utils import camel_case, given_variables, snake_case, untitle_case
1918

20-
from generator.python_formatter import (
21-
format_parameters as format_parameters_python,
22-
format_data_with_schema as format_data_with_schema_python,
23-
safe_snake_case,
24-
)
19+
from generator.formatter import format_parameters, format_data_with_schema, safe_snake_case, snake_case
2520

2621

2722
MODIFIED_FEATURES = {
@@ -84,28 +79,18 @@ def encode(self, obj):
8479
return result
8580

8681

87-
def json_dumps(*args, **kwargs):
88-
if CLIENT_REPO_NAME == "terraform-config":
89-
kwargs.setdefault("cls", FloatEncoder)
90-
return json.dumps(*args, **kwargs)
91-
92-
9382
JINJA_ENV = Environment(loader=FileSystemLoader(pathlib.Path(__file__).parent / "src" / "generator" / "templates"))
9483
JINJA_ENV.filters["tojson"] = json.dumps
9584
JINJA_ENV.filters["snake_case"] = snake_case
9685
JINJA_ENV.filters["safe_snake_case"] = safe_snake_case
97-
JINJA_ENV.filters["camel_case"] = camel_case
98-
JINJA_ENV.filters["untitle_case"] = untitle_case
99-
JINJA_ENV.globals["format_data_with_schema_python"] = format_data_with_schema_python
100-
JINJA_ENV.globals["format_parameters_python"] = format_parameters_python
101-
JINJA_ENV.globals["given_variables"] = given_variables
86+
JINJA_ENV.globals["format_data_with_schema"] = format_data_with_schema
87+
JINJA_ENV.globals["format_parameters"] = format_parameters
10288

10389
PYTHON_EXAMPLE_J2 = JINJA_ENV.get_template("example.j2")
10490

10591

10692
def pytest_bdd_after_scenario(request, feature, scenario):
10793
try:
108-
specs = request.getfixturevalue("specs")
10994
operation_specs = request.getfixturevalue("operation_specs")
11095
version = request.getfixturevalue("api_version")
11196
context = request.getfixturevalue("context")
@@ -547,7 +532,7 @@ def the_status_is(context, status, description):
547532
@then(parsers.parse('the response "{response_path}" is equal to {value}'))
548533
def expect_equal(context, response_path, value):
549534
"""Compare a response attribute to a value."""
550-
pass
535+
551536

552537
@then(
553538
parsers.parse(
@@ -556,16 +541,13 @@ def expect_equal(context, response_path, value):
556541
)
557542
def expect_equal_value(context, response_path, fixture_path):
558543
"""Compare a response attribute to another attribute."""
559-
pass
560544

561545

562546
@then(parsers.parse('the response "{response_path}" has length {fixture_length:d}'))
563547
def expect_equal_length(context, response_path, fixture_length):
564548
"""Check the length of a response attribute."""
565-
pass
566549

567550

568551
@then(parsers.parse('the response "{response_path}" is false'))
569552
def expect_false(context, response_path):
570553
"""Check that a response attribute is false."""
571-
pass

.generator/src/generator/formatter.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
"""Data formatter."""
2+
from collections import defaultdict
3+
import json
4+
from functools import singledispatch
5+
import pathlib
26
import keyword
7+
import warnings
38
import re
49

10+
import dateutil.parser
511
import m2r2
612

13+
14+
MODEL_IMPORT_TPL = "datadog_api_client.{version}.model.{name}"
15+
16+
EDGE_CASES = {}
17+
replacement_file = (
18+
pathlib.Path(__file__).parent
19+
/ "replacement.json"
20+
)
21+
if replacement_file.exists():
22+
with replacement_file.open() as f:
23+
EDGE_CASES.update(json.load(f))
24+
725
KEYWORDS = set(keyword.kwlist)
826
KEYWORDS.add("property")
927

@@ -21,6 +39,12 @@ def snake_case(value):
2139
return PATTERN_DOUBLE_UNDERSCORE.sub("_", s1)
2240

2341

42+
def safe_snake_case(value):
43+
for token, replacement in EDGE_CASES.items():
44+
value = value.replace(token, replacement)
45+
return snake_case(value)
46+
47+
2448
def camel_case(value):
2549
return "".join(x.title() for x in snake_case(value).split("_"))
2650

@@ -71,3 +95,245 @@ def header(self, text, level, raw=None):
7195

7296
def docstring(text):
7397
return m2r2.convert(text.replace("\\n", "\\\\n"), renderer=CustomRenderer())[1:-1].replace("\\ ", " ")
98+
99+
100+
def _merge_imports(a, b):
101+
"""Merge second set of imports into first one."""
102+
for k, v in b.items():
103+
a[k] |= v
104+
return a
105+
106+
107+
def format_parameters(kwargs, spec, version, replace_values=None):
108+
parameters = ""
109+
imports = defaultdict(set)
110+
111+
parameters_spec = {p["name"]: p for p in spec.get("parameters", [])}
112+
if (
113+
"requestBody" in spec
114+
and "multipart/form-data" in spec["requestBody"]["content"]
115+
):
116+
parent = spec["requestBody"]["content"]["multipart/form-data"]["schema"]
117+
for name, schema in parent["properties"].items():
118+
parameters_spec[name] = {
119+
"in": "form",
120+
"schema": schema,
121+
"name": name,
122+
"description": schema.get("description"),
123+
"required": name in parent.get("required", []),
124+
}
125+
126+
parameters = ""
127+
for p in parameters_spec.values():
128+
k = p["name"]
129+
if k not in kwargs:
130+
continue
131+
132+
v = kwargs[k]
133+
value, extra_imports = format_data_with_schema(
134+
v["value"],
135+
p["schema"],
136+
replace_values=replace_values,
137+
version=version,
138+
)
139+
imports = _merge_imports(imports, extra_imports)
140+
parameters += f"{escape_reserved_keyword(safe_snake_case(k))}={value}, "
141+
142+
return parameters, imports
143+
144+
145+
def get_name_and_imports(schema, version=None, imports=None):
146+
assert version is not None
147+
imports = imports or defaultdict(set)
148+
149+
name = None
150+
if hasattr(schema, "__reference__"):
151+
name = schema.__reference__["$ref"].split("/")[-1]
152+
if "oneOf" not in schema:
153+
# do not include parent of oneOf schema
154+
imports[
155+
MODEL_IMPORT_TPL.format(version=version, name=safe_snake_case(name))
156+
].add(name)
157+
158+
return name, imports
159+
160+
161+
@singledispatch
162+
def format_data_with_schema(
163+
data,
164+
schema,
165+
replace_values=None,
166+
default_name=None,
167+
version=None,
168+
imports=None,
169+
):
170+
"""Format data with schema."""
171+
assert version is not None
172+
173+
name = None
174+
imports = imports or defaultdict(set)
175+
if schema.get("type") not in {"string", "integer", "boolean"} or schema.get("enum"):
176+
name, imports = get_name_and_imports(schema, version, imports)
177+
if name:
178+
imports[
179+
MODEL_IMPORT_TPL.format(version=version, name=safe_snake_case(name))
180+
].add(name)
181+
182+
if "enum" in schema and data not in schema["enum"]:
183+
raise ValueError(f"{data} is not valid enum value {schema['enum']}")
184+
185+
if replace_values and data in replace_values:
186+
parameters = replace_values[data]
187+
if schema.get("format") in ("int32", "int64"):
188+
parameters = f"int({parameters})"
189+
else:
190+
if schema.get("nullable") and data is None:
191+
parameters = repr(data)
192+
else:
193+
194+
def format_datetime(x):
195+
imports["datetime"].add("datetime")
196+
d = dateutil.parser.isoparse(x)
197+
result = repr(d)
198+
if result.startswith("datetime."):
199+
result = result[len("datetime.") :]
200+
if "tzutc" in result:
201+
imports["dateutil.tz"].add("tzutc")
202+
return result
203+
204+
formatter = {
205+
"double": lambda s: repr(float(s)),
206+
"int32": lambda s: repr(int(s)),
207+
"int64": lambda s: repr(int(s)),
208+
"date": format_datetime,
209+
"date-time": format_datetime,
210+
"binary": lambda s: f'open("{s}", "rb")',
211+
"email": repr,
212+
None: repr,
213+
}[schema.get("format")]
214+
215+
# TODO format date and datetime
216+
parameters = formatter(data)
217+
218+
if name:
219+
return f"{name}({parameters})", imports
220+
221+
return parameters, imports
222+
223+
224+
@format_data_with_schema.register(list)
225+
def format_data_with_schema_list(
226+
data,
227+
schema,
228+
replace_values=None,
229+
default_name=None,
230+
version=None,
231+
imports=None,
232+
):
233+
"""Format data with schema."""
234+
assert version is not None
235+
name, imports = get_name_and_imports(schema, version, imports)
236+
237+
parameters = ""
238+
for d in data:
239+
value, extra_imports = format_data_with_schema(
240+
d,
241+
schema["items"],
242+
replace_values=replace_values,
243+
default_name=name,
244+
version=version,
245+
)
246+
parameters += f"{value}, "
247+
imports = _merge_imports(imports, extra_imports)
248+
parameters = f"[{parameters}]"
249+
250+
if name:
251+
return f"{name}({parameters})", imports
252+
253+
return parameters, imports
254+
255+
256+
@format_data_with_schema.register(dict)
257+
def format_data_with_schema_dict(
258+
data,
259+
schema,
260+
replace_values=None,
261+
default_name=None,
262+
version=None,
263+
imports=None,
264+
):
265+
"""Format data with schema."""
266+
assert version is not None
267+
name, imports = get_name_and_imports(schema, version, imports)
268+
269+
parameters = ""
270+
if "properties" in schema:
271+
for k, v in data.items():
272+
if k in schema["properties"]:
273+
sub_schema = schema["properties"][k]
274+
else:
275+
sub_schema = schema["additionalProperties"]
276+
value, extra_imports = format_data_with_schema(
277+
v,
278+
sub_schema,
279+
replace_values=replace_values,
280+
default_name=name + camel_case(k) if name else None,
281+
version=version,
282+
)
283+
parameters += f"{escape_reserved_keyword(safe_snake_case(k))}={value}, "
284+
imports = _merge_imports(imports, extra_imports)
285+
286+
if schema.get("additionalProperties") and not schema.get("properties"):
287+
for k, v in data.items():
288+
value, extra_imports = format_data_with_schema(
289+
v,
290+
schema["additionalProperties"],
291+
replace_values=replace_values,
292+
version=version,
293+
)
294+
parameters += f"{escape_reserved_keyword(k)}={value}, "
295+
imports = _merge_imports(imports, extra_imports)
296+
297+
if not name and "oneOf" not in schema:
298+
if (
299+
default_name
300+
and not schema.get("additionalProperties")
301+
and schema.get("properties")
302+
):
303+
name = default_name
304+
imports[
305+
MODEL_IMPORT_TPL.format(version=version, name=safe_snake_case(name))
306+
].add(name)
307+
else:
308+
name = "dict"
309+
warnings.warn(f"Unnamed schema {schema} for {data}")
310+
311+
if "oneOf" in schema:
312+
matched = 0
313+
for sub_schema in schema["oneOf"]:
314+
try:
315+
formatted, extra_imports = format_data_with_schema(
316+
data,
317+
sub_schema,
318+
replace_values=replace_values,
319+
version=version,
320+
)
321+
if matched == 0:
322+
imports = _merge_imports(imports, extra_imports)
323+
# NOTE we do not support mixed schemas with oneOf
324+
# parameters += formatted
325+
parameters = formatted
326+
name = None
327+
matched += 1
328+
except (KeyError, ValueError) as e:
329+
print(f"{e}")
330+
331+
if matched == 0:
332+
raise ValueError(f"[{matched}] {data} is not valid for schema {name}")
333+
elif matched > 1:
334+
warnings.warn(f"[{matched}] {data} is not valid for schema {name}")
335+
336+
if name:
337+
return f"{name}({parameters})", imports
338+
339+
return parameters, imports

0 commit comments

Comments
 (0)