
Commit 57fedc5

feat: add compile check tooling and CI workflow
Signed-off-by: sduvvuri1603 <[email protected]>
1 parent a097968 · commit 57fedc5

File tree

5 files changed: +563 -4 lines changed


.github/scripts/compile_check/__init__.py

Whitespace-only changes.
Lines changed: 382 additions & 0 deletions
@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
Compile and dependency validation tool for Kubeflow Pipelines components.

This script discovers component and pipeline modules based on the presence of
`metadata.yaml` files, validates declared dependencies, and ensures each target
compiles successfully with the Kubeflow Pipelines SDK.
"""

from __future__ import annotations

import argparse
import importlib
import logging
import sys
import tempfile
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

try:
    import yaml
except ImportError:  # pragma: no cover - import guard
    sys.stderr.write(
        "PyYAML is required to run compile_check.py. "
        "Install it with `pip install pyyaml`.\n"
    )
    raise

try:
    from packaging.specifiers import SpecifierSet
except ImportError:  # pragma: no cover - packaging is optional
    SpecifierSet = None  # type: ignore[assignment]

from kfp import compiler as pipeline_compiler
from kfp.dsl import base_component
from kfp.dsl import graph_component

REPO_ROOT = Path(__file__).resolve().parents[3]  # .github/scripts/compile_check/ -> repository root

@dataclass
class MetadataTarget:
    """Represents a single component or pipeline discovered from metadata."""

    metadata_path: Path
    module_path: Path
    module_import: str
    tier: str
    target_kind: str  # "component" or "pipeline"
    metadata: Dict


@dataclass
class ValidationResult:
    target: MetadataTarget
    success: bool
    compiled_objects: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)

    def add_error(self, message: str) -> None:
        logging.error(message)
        self.errors.append(message)
        self.success = False

    def add_warning(self, message: str) -> None:
        logging.warning(message)
        self.warnings.append(message)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Compile Kubeflow components and pipelines.")
    parser.add_argument(
        "--tier",
        choices=["core", "all"],
        default="all",
        help="Limit validation to the core tier only, or run across all assets (default: all).",
    )
    parser.add_argument(
        "--path",
        action="append",
        default=[],
        help="Restrict validation to metadata paths under this directory. May be supplied multiple times.",
    )
    parser.add_argument(
        "--fail-fast",
        action="store_true",
        help="Stop at the first validation failure.",
    )
    parser.add_argument(
        "--include-flagless",
        action="store_true",
        help="Include targets that do not set ci.compile_check explicitly.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging.",
    )
    return parser.parse_args(argv)


def configure_logging(verbose: bool) -> None:
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(levelname)s: %(message)s",
    )


def discover_metadata_files(tier: str) -> List[Tuple[Path, str, str]]:
    """Return a list of (metadata_path, tier, target_kind)."""
    if tier not in ("core", "all"):
        return []

    search_roots: List[Tuple[Path, str]] = [
        (REPO_ROOT / "components", "component"),
        (REPO_ROOT / "pipelines", "pipeline"),
    ]

    discovered: List[Tuple[Path, str, str]] = []
    for root, target_kind in search_roots:
        if not root.exists():
            continue
        for metadata_path in root.glob("**/metadata.yaml"):
            discovered.append((metadata_path, "core", target_kind))
    return discovered


def should_include_target(
    metadata: Dict,
    include_flagless: bool,
) -> bool:
    ci_config = metadata.get("ci") or {}
    if "compile_check" in ci_config:
        return bool(ci_config["compile_check"])
    return include_flagless


def build_module_import_path(module_path: Path) -> str:
    relative = module_path.relative_to(REPO_ROOT)
    return ".".join(relative.with_suffix("").parts)


def load_metadata(metadata_path: Path) -> Dict:
    with metadata_path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle) or {}
    if not isinstance(data, dict):
        raise ValueError(f"Metadata at {metadata_path} must be a mapping.")
    return data


def create_targets(
    discovered: Iterable[Tuple[Path, str, str]],
    include_flagless: bool,
    path_filters: Sequence[str],
) -> List[MetadataTarget]:
    normalized_filters = [Path(p).resolve() for p in path_filters]
    targets: List[MetadataTarget] = []

    for metadata_path, tier, target_kind in discovered:
        if normalized_filters:
            absolute_metadata_dir = metadata_path.parent.resolve()
            if not any(absolute_metadata_dir.is_relative_to(f) for f in normalized_filters):
                continue

        try:
            metadata = load_metadata(metadata_path)
        except Exception as exc:
            logging.error("Failed to read metadata %s: %s", metadata_path, exc)
            continue

        if not should_include_target(metadata, include_flagless):
            logging.debug("Skipping %s (compile_check disabled).", metadata_path)
            continue

        module_filename = "component.py" if target_kind == "component" else "pipeline.py"
        module_path = metadata_path.with_name(module_filename)
        if not module_path.exists():
            logging.error("Expected module %s not found for metadata %s", module_path, metadata_path)
            continue

        module_import = build_module_import_path(module_path)
        targets.append(
            MetadataTarget(
                metadata_path=metadata_path,
                module_path=module_path,
                module_import=module_import,
                tier=tier,
                target_kind=target_kind,
                metadata=metadata,
            )
        )
    return targets


def find_objects(module, target_kind: str) -> List[Tuple[str, base_component.BaseComponent]]:
    found: List[Tuple[str, base_component.BaseComponent]] = []
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if target_kind == "pipeline":
            if isinstance(attr, graph_component.GraphComponent):
                found.append((attr_name, attr))
        else:
            if isinstance(attr, base_component.BaseComponent) and not isinstance(
                attr, graph_component.GraphComponent
            ):
                found.append((attr_name, attr))
    return found


def validate_dependencies(metadata: Dict, result: ValidationResult) -> None:
    dependencies = metadata.get("dependencies") or {}
    if not isinstance(dependencies, dict):
        result.add_error("`dependencies` must be a mapping.")
        return

    sections = [
        ("kubeflow", "Kubeflow dependency"),
        ("external_services", "External service dependency"),
    ]

    for section_key, label in sections:
        entries = dependencies.get(section_key, [])
        if not entries:
            continue
        if not isinstance(entries, list):
            result.add_error(f"`dependencies.{section_key}` must be a list.")
            continue
        for entry in entries:
            if not isinstance(entry, dict):
                result.add_error(f"{label} entries must be mappings: {entry!r}")
                continue
            name = entry.get("name")
            version = entry.get("version")
            if not name:
                result.add_error(f"{label} is missing a `name` field.")
            if not version:
                result.add_error(f"{label} for {name or '<unknown>'} is missing a `version` field.")
            elif SpecifierSet is not None:
                try:
                    SpecifierSet(str(version))
                except Exception as exc:
                    result.add_error(
                        f"{label} for {name or '<unknown>'} has an invalid version specifier "
                        f"{version!r}: {exc}"
                    )
            else:
                result.add_warning(
                    "packaging module not available; skipping validation for dependency versions."
                )
    return


def compile_pipeline(obj: graph_component.GraphComponent, output_dir: Path) -> Path:
    output_path = output_dir / f"{obj.name or 'pipeline'}.json"
    pipeline_compiler.Compiler().compile(
        pipeline_func=obj,
        package_path=str(output_path),
    )
    return output_path


def compile_component(obj: base_component.BaseComponent, output_dir: Path) -> Path:
    output_path = output_dir / f"{obj.name or 'component'}.yaml"
    obj.component_spec.save_to_component_yaml(str(output_path))
    return output_path


def validate_target(target: MetadataTarget) -> ValidationResult:
    result = ValidationResult(target=target, success=True)
    validate_dependencies(target.metadata, result)
    if not result.success and result.errors:
        return result

    try:
        if target.module_import in sys.modules:
            del sys.modules[target.module_import]
        module = importlib.import_module(target.module_import)
    except Exception:
        result.add_error(
            f"Failed to import module {target.module_import} defined in {target.module_path}.\n"
            f"{traceback.format_exc()}"
        )
        return result

    objects = find_objects(module, target.target_kind)
    if not objects:
        result.add_error(
            f"No {target.target_kind} objects discovered in module {target.module_import}."
        )
        return result

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        for attr_name, obj in objects:
            try:
                if target.target_kind == "pipeline":
                    compiled_path = compile_pipeline(obj, temp_path)
                else:
                    compiled_path = compile_component(obj, temp_path)
                result.compiled_objects.append(f"{attr_name} -> {compiled_path.name}")
                logging.debug(
                    "Compiled %s from %s to %s",
                    attr_name,
                    target.module_import,
                    compiled_path,
                )
            except Exception:
                result.add_error(
                    f"Failed to compile {target.target_kind} `{attr_name}` from {target.module_import}.\n"
                    f"{traceback.format_exc()}"
                )
                if result.errors:
                    # stop compiling additional objects from this module to avoid noise
                    break

    return result


def run_validation(args: argparse.Namespace) -> int:
    configure_logging(args.verbose)
    sys.path.insert(0, str(REPO_ROOT))

    discovered = discover_metadata_files(args.tier)
    targets = create_targets(discovered, args.include_flagless, args.path)

    if not targets:
        logging.info("No targets discovered for compile check.")
        return 0

    results: List[ValidationResult] = []
    for target in targets:
        logging.info(
            "Validating %s (%s) from %s",
            target.metadata.get("name", target.module_import),
            target.target_kind,
            target.metadata_path,
        )
        result = validate_target(target)
        results.append(result)

        if result.success:
            logging.info(
                "✓ %s compiled successfully (%s)",
                target.metadata.get("name", target.module_import),
                ", ".join(result.compiled_objects) if result.compiled_objects else "no output",
            )
        else:
            logging.error(
                "✗ %s failed validation (%d error(s))",
                target.metadata.get("name", target.module_import),
                len(result.errors),
            )
            if args.fail_fast:
                break

    failed = [res for res in results if not res.success]
    logging.info("Validation complete: %d succeeded, %d failed.", len(results) - len(failed), len(failed))

    if failed:
        logging.error("Compile check failed for the targets listed above.")
        return 1
    return 0


def main(argv: Optional[Sequence[str]] = None) -> int:
    args = parse_args(argv)
    try:
        return run_validation(args)
    finally:
        # Ensure repo root is removed if we inserted it.
        if sys.path and sys.path[0] == str(REPO_ROOT):
            sys.path.pop(0)


if __name__ == "__main__":  # pragma: no cover - CLI entry point
    sys.exit(main())
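For reference, here is the shape of metadata.yaml this script consumes. The keys (name, ci.compile_check, dependencies.kubeflow, dependencies.external_services, and the per-entry name/version fields) are taken from the validation logic above; the path, component name, and version ranges are purely illustrative:

# components/example/metadata.yaml (hypothetical)
name: example-component
ci:
  compile_check: true   # opt in; without this key the target is skipped unless --include-flagless is set
dependencies:
  kubeflow:
    - name: kfp
      version: ">=2.0"  # must parse as a PEP 440 specifier set when `packaging` is installed
  external_services:
    - name: object-storage
      version: ">=1.0"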

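Discovery also expects the target module to sit next to its metadata.yaml, named component.py for components or pipeline.py for pipelines. A minimal sketch of a component module the checker could import and compile, continuing the hypothetical layout above:

# components/example/component.py (hypothetical)
from kfp import dsl


@dsl.component(base_image="python:3.11")
def example_component(name: str) -> str:
    """Trivial component; find_objects() reports it as a BaseComponent."""
    return f"Hello, {name}!"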
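Typical invocations, assuming the script is saved as .github/scripts/compile_check/compile_check.py (the exact filename is not visible in this extract):

# validate every target that opts in via ci.compile_check
python .github/scripts/compile_check/compile_check.py --tier all

# restrict to one directory, stop at the first failure, and log verbosely
python .github/scripts/compile_check/compile_check.py --path components/example --fail-fast --verbose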