Skip to content

Commit 50547e8

Browse files
Reformat aarch64_builtins_test_generator.py
1 parent c0320ff commit 50547e8

File tree

1 file changed

+130
-48
lines changed

1 file changed

+130
-48
lines changed

clang/utils/aarch64_builtins_test_generator.py

Lines changed: 130 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,19 @@
2626
from pathlib import Path
2727
from typing import Dict, Iterable, List, Sequence, Tuple
2828

29+
2930
# Are we testing arm_sve.h or arm_sme.h based builtins.
3031
class Mode(Enum):
3132
SVE = "sve"
3233
SME = "sme"
3334

35+
3436
class FunctionType(Enum):
3537
NORMAL = "normal"
3638
STREAMING = "streaming"
3739
STREAMING_COMPATIBLE = "streaming-compatible"
3840

41+
3942
# Builtins are grouped by their required features.
4043
@dataclass(frozen=True, order=True, slots=True)
4144
class BuiltinContext:
@@ -44,16 +47,19 @@ class BuiltinContext:
4447
flags: tuple[str, ...]
4548

4649
def __str__(self) -> str:
47-
return (f'// Properties: '
48-
f'guard="{self.guard}" '
49-
f'streaming_guard="{self.streaming_guard}" '
50-
f'flags="{",".join(self.flags)}"')
50+
return (
51+
f"// Properties: "
52+
f'guard="{self.guard}" '
53+
f'streaming_guard="{self.streaming_guard}" '
54+
f'flags="{",".join(self.flags)}"'
55+
)
5156

5257
@classmethod
5358
def from_json(cls, obj: dict[str, Any]) -> "BuiltinContext":
5459
flags = tuple(p.strip() for p in obj["flags"].split(",") if p.strip())
5560
return cls(obj["guard"], obj["streaming_guard"], flags)
5661

62+
5763
# --- Parsing builtins -------------------------------------------------------
5864

5965
# Captures the full function *declaration* inside the builtin string, e.g.:
@@ -65,6 +71,7 @@ def from_json(cls, obj: dict[str, Any]) -> "BuiltinContext":
6571
# Pulls the final word out of the left side (the function name).
6672
NAME_RE = re.compile(r"([a-zA-Z_][\w]*)\s*$")
6773

74+
6875
def parse_builtin_declaration(decl: str) -> Tuple[str, List[str]]:
6976
"""Return (func_name, param_types) from a builtin declaration string.
7077
@@ -88,10 +95,11 @@ def parse_builtin_declaration(decl: str) -> Tuple[str, List[str]]:
8895
param_types: List[str] = []
8996
else:
9097
# Split by commas respecting no pointers/arrays with commas (not expected here)
91-
param_types = [p.strip() for p in params.split(',') if p.strip()]
98+
param_types = [p.strip() for p in params.split(",") if p.strip()]
9299

93100
return func_name, param_types
94101

102+
95103
# --- Variable synthesis -----------------------------------------------------
96104

97105
# Pick a safe (ideally non-zero) value for literal types
@@ -125,8 +133,9 @@ def parse_builtin_declaration(decl: str) -> Tuple[str, List[str]]:
125133
"ImmCheckShiftRight": "2",
126134
"enum svpattern": "SV_MUL3",
127135
"enum svprfop": "SV_PSTL1KEEP",
128-
"void": ""
129-
}
136+
"void": "",
137+
}
138+
130139

131140
def make_arg_for_type(ty: str) -> Tuple[str, str]:
132141
"""Return (var_decl, var_use) for a parameter type.
@@ -142,18 +151,21 @@ def make_arg_for_type(ty: str) -> Tuple[str, str]:
142151
name = ty.replace(" ", "_").replace("*", "ptr") + "_val"
143152
return f"{ty} {name};", name
144153

154+
145155
# NOTE: Parsing is limited to the minimum required for guard strings.
146156
# Specifically the expected input is of the form:
147157
# feat1,feat2,...(feat3 | feat4 | ...),...
148-
def expand_feature_guard(guard: str, flags: str, base_feature: str = None) -> list[set[str]]:
158+
def expand_feature_guard(
159+
guard: str, flags: str, base_feature: str = None
160+
) -> list[set[str]]:
149161
"""
150162
Expand a guard expression where ',' = AND and '|' = OR, with parentheses
151163
grouping OR-expressions. Returns a list of feature sets.
152164
"""
153165
if not guard:
154-
return [];
166+
return []
155167

156-
parts = re.split(r',(?![^(]*\))', guard)
168+
parts = re.split(r",(?![^(]*\))", guard)
157169

158170
choices_per_part = []
159171
for part in parts:
@@ -185,9 +197,11 @@ def expand_feature_guard(guard: str, flags: str, base_feature: str = None) -> li
185197

186198
return unique
187199

200+
188201
def cc1_args_for_features(features: set[str]) -> str:
189202
return " ".join("-target-feature +" + s for s in sorted(features))
190203

204+
191205
def sanitise_guard(s: str) -> str:
192206
"""Rewrite guard strings in a form more suitable for file naming."""
193207
replacements = {
@@ -203,6 +217,7 @@ def sanitise_guard(s: str) -> str:
203217
s = re.sub(r"_+", "_", s)
204218
return s.strip("_")
205219

220+
206221
def make_filename(prefix: str, ctx: BuiltinContext, ext: str) -> str:
207222
parts = [sanitise_guard(ctx.guard), sanitise_guard(ctx.streaming_guard)]
208223
sanitised_guard = "___".join(p for p in parts if p)
@@ -218,22 +233,28 @@ def make_filename(prefix: str, ctx: BuiltinContext, ext: str) -> str:
218233

219234
return f"{prefix}_{sanitised_guard}{ext}"
220235

236+
221237
# --- Code Generation --------------------------------------------------------
222238

239+
223240
def emit_streaming_guard_run_lines(ctx: BuiltinContext) -> str:
224241
"""Emit lit RUN lines that will exercise the relevent Sema diagnistics."""
225242
run_prefix = "// RUN: %clang_cc1 %s -fsyntax-only -triple aarch64-none-linux-gnu"
226243
out: List[str] = []
227244

228245
# All RUN lines have SVE and SME enabled
229246
guard_features = expand_feature_guard(ctx.guard, ctx.flags, "sme")
230-
streaming_guard_features = expand_feature_guard(ctx.streaming_guard, ctx.flags, "sve")
247+
streaming_guard_features = expand_feature_guard(
248+
ctx.streaming_guard, ctx.flags, "sve"
249+
)
231250

232251
if "streaming-only" in ctx.flags:
233252
assert not guard_features
234253
# Generate RUN lines for features only availble to streaming functions
235254
for feats in streaming_guard_features:
236-
out.append(f"{run_prefix} {cc1_args_for_features(feats)} -verify=streaming-guard")
255+
out.append(
256+
f"{run_prefix} {cc1_args_for_features(feats)} -verify=streaming-guard"
257+
)
237258
elif "streaming-compatible" in ctx.flags:
238259
assert not guard_features
239260
# NOTE: Streaming compatible builtins don't require SVE.
@@ -243,7 +264,9 @@ def emit_streaming_guard_run_lines(ctx: BuiltinContext) -> str:
243264
out.append("// expected-no-diagnostics")
244265
elif "feature-dependent" in ctx.flags:
245266
assert guard_features and streaming_guard_features
246-
combined_features = expand_feature_guard(ctx.guard + "," + ctx.streaming_guard, ctx.flags)
267+
combined_features = expand_feature_guard(
268+
ctx.guard + "," + ctx.streaming_guard, ctx.flags
269+
)
247270

248271
# Generate RUN lines for features only availble to normal functions
249272
for feats in guard_features:
@@ -253,7 +276,9 @@ def emit_streaming_guard_run_lines(ctx: BuiltinContext) -> str:
253276
# Geneate RUN lines for features only available to streaming functions
254277
for feats in streaming_guard_features:
255278
if feats not in combined_features:
256-
out.append(f"{run_prefix} {cc1_args_for_features(feats)} -verify=streaming-guard")
279+
out.append(
280+
f"{run_prefix} {cc1_args_for_features(feats)} -verify=streaming-guard"
281+
)
257282

258283
# Generate RUN lines for features available to all functions
259284
for feats in combined_features:
@@ -268,7 +293,14 @@ def emit_streaming_guard_run_lines(ctx: BuiltinContext) -> str:
268293

269294
return "\n".join(out)
270295

271-
def emit_streaming_guard_function(ctx: BuiltinContext, var_decls: Sequence[str], calls: Sequence[str], func_name: str, func_type: FunctionType = FunctionType.NORMAL) -> str:
296+
297+
def emit_streaming_guard_function(
298+
ctx: BuiltinContext,
299+
var_decls: Sequence[str],
300+
calls: Sequence[str],
301+
func_name: str,
302+
func_type: FunctionType = FunctionType.NORMAL,
303+
) -> str:
272304
"""Emit a C function calling all builtins.
273305
274306
`calls` is a sequence of tuples: (name, call_line)
@@ -279,17 +311,23 @@ def emit_streaming_guard_function(ctx: BuiltinContext, var_decls: Sequence[str],
279311
if func_type != FunctionType.STREAMING:
280312
require_streaming_diagnostic = True
281313
elif "streaming-compatible" in ctx.flags:
282-
pass # streaming compatible builtins are always available
314+
pass # streaming compatible builtins are always available
283315
elif "feature-dependent" in ctx.flags:
284316
guard_features = expand_feature_guard(ctx.guard, ctx.flags, "sme")
285-
streaming_guard_features = expand_feature_guard(ctx.streaming_guard, ctx.flags, "sve")
286-
combined_features = expand_feature_guard(ctx.guard + "," + ctx.streaming_guard, ctx.flags)
317+
streaming_guard_features = expand_feature_guard(
318+
ctx.streaming_guard, ctx.flags, "sve"
319+
)
320+
combined_features = expand_feature_guard(
321+
ctx.guard + "," + ctx.streaming_guard, ctx.flags
322+
)
287323

288324
if func_type != FunctionType.NORMAL:
289325
if any(feats not in combined_features for feats in guard_features):
290326
require_diagnostic = True
291327
if func_type != FunctionType.STREAMING:
292-
if any(feats not in combined_features for feats in streaming_guard_features):
328+
if any(
329+
feats not in combined_features for feats in streaming_guard_features
330+
):
293331
require_streaming_diagnostic = True
294332
else:
295333
if func_type != FunctionType.NORMAL:
@@ -319,26 +357,35 @@ def emit_streaming_guard_function(ctx: BuiltinContext, var_decls: Sequence[str],
319357
# Emit calls
320358
for call in calls:
321359
if require_diagnostic and require_streaming_diagnostic:
322-
out.append(" // guard-error@+2 {{builtin can only be called from a non-streaming function}}")
323-
out.append(" // streaming-guard-error@+1 {{builtin can only be called from a streaming function}}")
360+
out.append(
361+
" // guard-error@+2 {{builtin can only be called from a non-streaming function}}"
362+
)
363+
out.append(
364+
" // streaming-guard-error@+1 {{builtin can only be called from a streaming function}}"
365+
)
324366
elif require_diagnostic:
325-
out.append(" // guard-error@+1 {{builtin can only be called from a non-streaming function}}")
367+
out.append(
368+
" // guard-error@+1 {{builtin can only be called from a non-streaming function}}"
369+
)
326370
elif require_streaming_diagnostic:
327-
out.append(" // streaming-guard-error@+1 {{builtin can only be called from a streaming function}}")
371+
out.append(
372+
" // streaming-guard-error@+1 {{builtin can only be called from a streaming function}}"
373+
)
328374
out.append(f" {call}")
329375

330376
out.append("}")
331377
return "\n".join(out) + "\n"
332378

379+
333380
def natural_key(s: str):
334381
"""Allow sorting akin to "sort -V"""
335-
return [int(text) if text.isdigit() else text
336-
for text in re.split(r'(\d+)', s)]
382+
return [int(text) if text.isdigit() else text for text in re.split(r"(\d+)", s)]
383+
337384

338385
def build_calls_for_group(builtins: Iterable[str]) -> Tuple[List[str], List[str]]:
339386
"""From a list of builtin declaration strings, produce:
340-
- a sorted list of unique variable declarations
341-
- a sorted list of builtin calls
387+
- a sorted list of unique variable declarations
388+
- a sorted list of builtin calls
342389
"""
343390
var_decls: List[str] = []
344391
seen_types: set[str] = set()
@@ -363,6 +410,7 @@ def build_calls_for_group(builtins: Iterable[str]) -> Tuple[List[str], List[str]
363410

364411
return var_decls, calls
365412

413+
366414
def gen_streaming_guard_tests(mode: MODE, json_path: Path, out_dir: Path) -> None:
367415
"""Generate a set of Clang Sema test files to ensure SVE/SME builtins are
368416
callable based on the function type, or the required diagnostic is emitted.
@@ -381,8 +429,10 @@ def gen_streaming_guard_tests(mode: MODE, json_path: Path, out_dir: Path) -> Non
381429
for builtin_ctx, builtin_decls in by_guard.items():
382430
var_decls, calls = build_calls_for_group(builtin_decls)
383431

384-
out_parts: List[str] = [];
385-
out_parts.append("// NOTE: File has been autogenerated by utils/aarch64_builtins_test_generator.py")
432+
out_parts: List[str] = []
433+
out_parts.append(
434+
"// NOTE: File has been autogenerated by utils/aarch64_builtins_test_generator.py"
435+
)
386436
out_parts.append(emit_streaming_guard_run_lines(builtin_ctx))
387437
out_parts.append("")
388438
out_parts.append("// REQUIRES: aarch64-registered-target")
@@ -391,9 +441,23 @@ def gen_streaming_guard_tests(mode: MODE, json_path: Path, out_dir: Path) -> Non
391441
out_parts.append("")
392442
out_parts.append(str(builtin_ctx))
393443
out_parts.append("")
394-
out_parts.append(emit_streaming_guard_function(builtin_ctx, var_decls, calls, "test"))
395-
out_parts.append(emit_streaming_guard_function(builtin_ctx, var_decls, calls, "test_streaming", FunctionType.STREAMING))
396-
out_parts.append(emit_streaming_guard_function(builtin_ctx, var_decls, calls, "test_streaming_compatible", FunctionType.STREAMING_COMPATIBLE))
444+
out_parts.append(
445+
emit_streaming_guard_function(builtin_ctx, var_decls, calls, "test")
446+
)
447+
out_parts.append(
448+
emit_streaming_guard_function(
449+
builtin_ctx, var_decls, calls, "test_streaming", FunctionType.STREAMING
450+
)
451+
)
452+
out_parts.append(
453+
emit_streaming_guard_function(
454+
builtin_ctx,
455+
var_decls,
456+
calls,
457+
"test_streaming_compatible",
458+
FunctionType.STREAMING_COMPATIBLE,
459+
)
460+
)
397461

398462
output = "\n".join(out_parts).rstrip() + "\n"
399463

@@ -402,42 +466,59 @@ def gen_streaming_guard_tests(mode: MODE, json_path: Path, out_dir: Path) -> Non
402466
filename = make_filename(f"arm_{mode.value}", builtin_ctx, ".c")
403467
(out_dir / filename).write_text(output)
404468
else:
405-
print(output)
469+
print(output)
406470

407471
return 0
408472

473+
409474
# --- Main -------------------------------------------------------------------
410475

476+
411477
def existing_file(path: str) -> Path:
412478
p = Path(path)
413479
if not p.is_file():
414480
raise argparse.ArgumentTypeError(f"{p} is not a valid file")
415481
return p
416482

483+
417484
def main(argv: Sequence[str] | None = None) -> int:
418485
ap = argparse.ArgumentParser(description="Emit C tests for SVE/SME builtins")
419-
ap.add_argument("json", type=existing_file,
420-
help="Path to json formatted builtin descriptions")
421-
ap.add_argument("--out-dir", type=Path, default=None,
422-
help="Output directory (default: stdout)")
423-
ap.add_argument("--gen-streaming-guard-tests", action="store_true",
424-
help="Generate C tests to validate SVE/SME builtin usage base on streaming attribute")
425-
ap.add_argument("--gen-target-guard-tests", action="store_true",
426-
help="Generate C tests to validate SVE/SME builtin usage based on target features")
427-
ap.add_argument("--gen-builtin-tests", action="store_true",
428-
help="Generate C tests to exercise SVE/SME builtins")
429-
ap.add_argument("--base-target-feature", choices=["sve", "sme"],
430-
help="Force builtin source (sve: arm_sve.h, sme: arm_sme.h)")
486+
ap.add_argument(
487+
"json", type=existing_file, help="Path to json formatted builtin descriptions"
488+
)
489+
ap.add_argument(
490+
"--out-dir", type=Path, default=None, help="Output directory (default: stdout)"
491+
)
492+
ap.add_argument(
493+
"--gen-streaming-guard-tests",
494+
action="store_true",
495+
help="Generate C tests to validate SVE/SME builtin usage base on streaming attribute",
496+
)
497+
ap.add_argument(
498+
"--gen-target-guard-tests",
499+
action="store_true",
500+
help="Generate C tests to validate SVE/SME builtin usage based on target features",
501+
)
502+
ap.add_argument(
503+
"--gen-builtin-tests",
504+
action="store_true",
505+
help="Generate C tests to exercise SVE/SME builtins",
506+
)
507+
ap.add_argument(
508+
"--base-target-feature",
509+
choices=["sve", "sme"],
510+
help="Force builtin source (sve: arm_sve.h, sme: arm_sme.h)",
511+
)
431512

432513
args = ap.parse_args(argv)
433514

434515
# When not forced, try to infer the mode from the input, defaulting to SVE.
435516
if args.base_target_feature:
436-
mode=Mode(args.base_target_feature)
517+
mode = Mode(args.base_target_feature)
437518
elif args.json and args.json.name == "arm_sme_builtins.json":
438-
mode=Mode.SME
519+
mode = Mode.SME
439520
else:
440-
mode=Mode.SVE
521+
mode = Mode.SVE
441522

442523
# Generate test file
443524
if args.gen_streaming_guard_tests:
@@ -449,5 +530,6 @@ def main(argv: Sequence[str] | None = None) -> int:
449530

450531
return 0
451532

533+
452534
if __name__ == "__main__":
453535
raise SystemExit(main())

0 commit comments

Comments
 (0)