Skip to content

Commit 745a3f7

Browse files
committed
test gen for async too
apply testgen-async fix bug when iterating over star imports fix cst * import errors
1 parent bafad49 commit 745a3f7

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

codeflash/api/aiservice.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,16 @@ def generate_regression_tests( # noqa: D417
439439
"test_index": test_index,
440440
"python_version": platform.python_version(),
441441
"codeflash_version": codeflash_version,
442+
"is_async": bool(getattr(function_to_optimize, "is_async", False)),
442443
}
444+
445+
endpoint = "/testgen"
446+
logger.debug(
447+
f"Using unified test generation endpoint for function {function_to_optimize.function_name} (is_async={payload['is_async']})"
448+
)
449+
443450
try:
444-
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
451+
response = self.make_ai_service_request(endpoint, payload=payload, timeout=600)
445452
except requests.exceptions.RequestException as e:
446453
logger.exception(f"Error generating tests: {e}")
447454
ph("cli-testgen-error-caught", {"error": str(e)})

codeflash/code_utils/code_extractor.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272
if child.module is None:
273273
continue
274274
module = self.get_full_dotted_name(child.module)
275+
if isinstance(child.names, cst.ImportStar):
276+
continue
275277
for alias in child.names:
276278
if isinstance(alias, cst.ImportAlias):
277279
name = alias.name.value
@@ -403,6 +405,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
403405
return transformed_module.code
404406

405407

408+
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
409+
try:
410+
module_path = module_name.replace(".", "/")
411+
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
412+
413+
module_file = None
414+
for path in possible_paths:
415+
if path.exists():
416+
module_file = path
417+
break
418+
419+
if module_file is None:
420+
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
421+
return set()
422+
423+
with module_file.open(encoding="utf8") as f:
424+
module_code = f.read()
425+
426+
tree = ast.parse(module_code)
427+
428+
all_names = None
429+
for node in ast.walk(tree):
430+
if (
431+
isinstance(node, ast.Assign)
432+
and len(node.targets) == 1
433+
and isinstance(node.targets[0], ast.Name)
434+
and node.targets[0].id == "__all__"
435+
):
436+
if isinstance(node.value, (ast.List, ast.Tuple)):
437+
all_names = []
438+
for elt in node.value.elts:
439+
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
440+
all_names.append(elt.value)
441+
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
442+
all_names.append(elt.s)
443+
break
444+
445+
if all_names is not None:
446+
return set(all_names)
447+
else:
448+
public_names = set()
449+
for node in tree.body:
450+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
451+
if not node.name.startswith("_"):
452+
public_names.add(node.name)
453+
elif isinstance(node, ast.Assign):
454+
for target in node.targets:
455+
if isinstance(target, ast.Name) and not target.id.startswith("_"):
456+
public_names.add(target.id)
457+
elif isinstance(node, ast.AnnAssign):
458+
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
459+
public_names.add(node.target.id)
460+
elif isinstance(node, ast.Import) or (
461+
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
462+
):
463+
for alias in node.names:
464+
name = alias.asname or alias.name
465+
if not name.startswith("_"):
466+
public_names.add(name)
467+
468+
return public_names
469+
470+
except Exception as e:
471+
logger.warning(f"Error resolving star import for {module_name}: {e}")
472+
return set()
473+
474+
406475
def add_needed_imports_from_module(
407476
src_module_code: str,
408477
dst_module_code: str,
@@ -457,9 +526,23 @@ def add_needed_imports_from_module(
457526
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
458527
):
459528
continue # Skip adding imports for helper functions already in the context
460-
if f"{mod}.{obj}" not in dotted_import_collector.imports:
461-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
462-
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
529+
530+
# Handle star imports by resolving them to actual symbol names
531+
if obj == "*":
532+
resolved_symbols = resolve_star_import(mod, project_root)
533+
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
534+
535+
for symbol in resolved_symbols:
536+
if (
537+
f"{mod}.{symbol}" not in helper_functions_fqn
538+
and f"{mod}.{symbol}" not in dotted_import_collector.imports
539+
):
540+
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
541+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
542+
else:
543+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
544+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
545+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
463546
except Exception as e:
464547
logger.exception(f"Error adding imports to destination module code: {e}")
465548
return dst_module_code

0 commit comments

Comments
 (0)