|
24 | 24 |
|
25 | 25 | import ast |
26 | 26 | import os |
| 27 | +import glob |
| 28 | +import logging |
| 29 | +import re |
27 | 30 | from collections import defaultdict |
28 | 31 | from typing import List, Dict, Any, Iterator |
29 | 32 |
|
@@ -65,7 +68,7 @@ def _get_type_str(self, node: ast.AST | None) -> str | None: |
65 | 68 | if isinstance(curr, ast.Name): |
66 | 69 | parts.append(curr.id) |
67 | 70 | return ".".join(reversed(parts)) |
68 | | - # Handles subscripted types like 'list[str]', 'Optional[...]' |
| 71 | + # Handles subscripted types like 'list[str]', 'Optional[...]' |
69 | 72 | if isinstance(node, ast.Subscript): |
70 | 73 | value_str = self._get_type_str(node.value) |
71 | 74 | slice_str = self._get_type_str(node.slice) |
@@ -352,3 +355,138 @@ def process_structure( |
352 | 355 | return sorted(all_class_keys) |
353 | 356 | else: |
354 | 357 | return dict(results) |
| 358 | + |
| 359 | + |
| 360 | +# ============================================================================= |
| 361 | +# Section 2: Source file data gathering |
| 362 | +# ============================================================================= |
| 363 | + |
| 364 | + |
| 365 | +def _should_include_class(class_name: str, class_filters: Dict[str, Any]) -> bool: |
| 366 | + """Checks if a class should be included based on filter criteria.""" |
| 367 | + if class_filters.get("include_suffixes"): |
| 368 | + if not class_name.endswith(tuple(class_filters["include_suffixes"])): |
| 369 | + return False |
| 370 | + if class_filters.get("exclude_suffixes"): |
| 371 | + if class_name.endswith(tuple(class_filters["exclude_suffixes"])): |
| 372 | + return False |
| 373 | + return True |
| 374 | + |
| 375 | + |
| 376 | +def _should_include_method(method_name: str, method_filters: Dict[str, Any]) -> bool: |
| 377 | + """Checks if a method should be included based on filter criteria.""" |
| 378 | + if method_filters.get("include_prefixes"): |
| 379 | + if not any( |
| 380 | + method_name.startswith(p) for p in method_filters["include_prefixes"] |
| 381 | + ): |
| 382 | + return False |
| 383 | + if method_filters.get("exclude_prefixes"): |
| 384 | + if any(method_name.startswith(p) for p in method_filters["exclude_prefixes"]): |
| 385 | + return False |
| 386 | + return True |
| 387 | + |
| 388 | + |
| 389 | +def _build_request_arg_schema( |
| 390 | + source_files: List[str], project_root: str |
| 391 | +) -> Dict[str, List[str]]: |
| 392 | + """Parses type files to build a schema of request classes and their _id arguments.""" |
| 393 | + request_arg_schema: Dict[str, List[str]] = {} |
| 394 | + for file_path in source_files: |
| 395 | + if "/types/" not in file_path: |
| 396 | + continue |
| 397 | + |
| 398 | + # Correctly determine the module name from the file path |
| 399 | + relative_path = os.path.relpath(file_path, project_root) |
| 400 | + module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, ".") |
| 401 | + |
| 402 | + try: |
| 403 | + structure, _, _ = parse_file(file_path) |
| 404 | + if not structure: |
| 405 | + continue |
| 406 | + |
| 407 | + for class_info in structure: |
| 408 | + class_name = class_info.get("class_name", "Unknown") |
| 409 | + if class_name.endswith("Request"): |
| 410 | + full_class_name = f"{module_name}.{class_name}" |
| 411 | + id_args = [ |
| 412 | + attr["name"] |
| 413 | + for attr in class_info.get("attributes", []) |
| 414 | + if attr.get("name", "").endswith("_id") |
| 415 | + ] |
| 416 | + if id_args: |
| 417 | + request_arg_schema[full_class_name] = id_args |
| 418 | + except Exception as e: |
| 419 | + logging.warning(f"Failed to parse {file_path}: {e}") |
| 420 | + return request_arg_schema |
| 421 | + |
| 422 | + |
| 423 | +def _process_service_clients( |
| 424 | + source_files: List[str], class_filters: Dict, method_filters: Dict |
| 425 | +) -> tuple[defaultdict, set, set]: |
| 426 | + """Parses service client files to extract class and method information.""" |
| 427 | + parsed_data = defaultdict(dict) |
| 428 | + all_imports: set[str] = set() |
| 429 | + all_types: set[str] = set() |
| 430 | + |
| 431 | + for file_path in source_files: |
| 432 | + if "/services/" not in file_path: |
| 433 | + continue |
| 434 | + |
| 435 | + structure, imports, types = parse_file(file_path) |
| 436 | + all_imports.update(imports) |
| 437 | + all_types.update(types) |
| 438 | + |
| 439 | + for class_info in structure: |
| 440 | + class_name = class_info["class_name"] |
| 441 | + if not _should_include_class(class_name, class_filters): |
| 442 | + continue |
| 443 | + |
| 444 | + parsed_data[class_name] # Ensure class is in dict |
| 445 | + |
| 446 | + for method in class_info["methods"]: |
| 447 | + method_name = method["method_name"] |
| 448 | + if not _should_include_method(method_name, method_filters): |
| 449 | + continue |
| 450 | + parsed_data[class_name][method_name] = method |
| 451 | + return parsed_data, all_imports, all_types |
| 452 | + |
| 453 | + |
| 454 | +def analyze_source_files( |
| 455 | + config: Dict[str, Any], |
| 456 | +) -> tuple[Dict[str, Any], set[str], set[str], Dict[str, List[str]]]: |
| 457 | + """ |
| 458 | + Analyzes source files per the configuration to extract class and method info, |
| 459 | + as well as information on imports and typehints. |
| 460 | +
|
| 461 | + Args: |
| 462 | + config: The generator's configuration dictionary. |
| 463 | +
|
| 464 | + Returns: |
| 465 | + A tuple containing: |
| 466 | + - A dictionary containing the data needed for template rendering. |
| 467 | + - A set of all import statements required by the parsed methods. |
| 468 | + - A set of all type annotations found in the parsed methods. |
| 469 | + - A dictionary mapping request class names to their `_id` arguments. |
| 470 | + """ |
| 471 | + project_root = config["project_root"] |
| 472 | + source_patterns_dict = config.get("source_files", {}) |
| 473 | + filter_rules = config.get("filter", {}) |
| 474 | + class_filters = filter_rules.get("classes", {}) |
| 475 | + method_filters = filter_rules.get("methods", {}) |
| 476 | + |
| 477 | + source_files = [] |
| 478 | + for group in source_patterns_dict.values(): |
| 479 | + for pattern in group: |
| 480 | + # Make the pattern absolute |
| 481 | + absolute_pattern = os.path.join(project_root, pattern) |
| 482 | + source_files.extend(glob.glob(absolute_pattern, recursive=True)) |
| 483 | + |
| 484 | + # PASS 1: Build the request argument schema from the types files. |
| 485 | + request_arg_schema = _build_request_arg_schema(source_files, project_root) |
| 486 | + |
| 487 | + # PASS 2: Process the service client files. |
| 488 | + parsed_data, all_imports, all_types = _process_service_clients( |
| 489 | + source_files, class_filters, method_filters |
| 490 | + ) |
| 491 | + |
| 492 | + return parsed_data, all_imports, all_types, request_arg_schema |
0 commit comments