1+ # --------------------------------------------------------------------------------------------
2+ # Copyright (c) Microsoft Corporation. All rights reserved.
3+ # Licensed under the MIT License. See License.txt in the project root for license information.
4+ # --------------------------------------------------------------------------------------------
5+
6+ from pathlib import Path
7+ import os
8+ import asyncio
9+ import random
10+ from fastmcp import Context
11+ from prompt_templates import get_testgen_static_instructions , REF_STYLE_LABEL
12+
13+
14+ async def check_module_status (ctx : Context , paths : dict ):
15+ await ctx .info ("Starting test generation workflow." )
16+
17+ module_name = getattr (ctx , "generated_module" , None )
18+ if not module_name :
19+ response = await ctx .elicit ("Enter the module/extension name to generate tests for:" )
20+ if response .action != "accept" or not response .data :
21+ return "Test generation cancelled."
22+ module_name = response .data
23+ else :
24+ await ctx .info (f"Detected generated module: { module_name } " )
25+
26+ ctx .generated_module = module_name
27+ return module_name
28+
29+ def find_test_dir (cli_extension_path : str , module_name : str ) -> Path | None :
30+ base = Path (cli_extension_path ) / "src" / module_name
31+ for path in base .rglob ("tests/latest" ):
32+ if path .is_dir ():
33+ return path
34+ return None
35+
36+ async def check_path_status (ctx : Context , paths : dict ):
37+ module_name = await check_module_status (ctx , paths )
38+ if not module_name or "cancelled" in str (module_name ).lower ():
39+ return str (module_name ), [], None
40+
41+ aaz_path = Path (
42+ f"{ paths ['cli_extension' ]} /src/{ module_name } " )
43+ if not aaz_path .exists ():
44+ return f"AAZ path not found for module '{ module_name } '" , [], None
45+
46+ commands = []
47+ for file in aaz_path .rglob ("*.py" ):
48+ with open (file , "r" , encoding = "utf-8" ) as f :
49+ for line in f :
50+ if line .strip ().startswith ("def " ):
51+ commands .append (line .strip ().replace (
52+ "def " , "" ).split ("(" )[0 ])
53+
54+ test_dir = find_test_dir (paths ["cli_extension" ], module_name )
55+ test_dir .mkdir (parents = True , exist_ok = True )
56+ test_file = test_dir / f"test_{ module_name } .py"
57+ return module_name , commands , test_file
58+
59+ def build_testgen_prompt (module_name : str , commands : list [str ], reference_snippet : str = "" ) -> str :
60+ parts = [
61+ get_testgen_static_instructions (),
62+ (
63+ f"Module name: '{ module_name } '. Generate a single test class named "
64+ f"'{ module_name .capitalize ()} ScenarioTest' deriving from ScenarioTest."
65+ ),
66+ "Discovered AAZ functions (potential commands):\n " +
67+ ", " .join (commands [:30 ])
68+ ]
69+ if reference_snippet :
70+ parts .append (REF_STYLE_LABEL + reference_snippet )
71+ return "\n \n " .join (parts )
72+
73+ def strip_code_fences (text : str ) -> str :
74+ lines = text .strip ().splitlines ()
75+ blocks = []
76+ inside = False
77+ current = []
78+
79+ for line in lines :
80+ if line .strip ().startswith ("```" ):
81+ if inside :
82+ blocks .append ("\n " .join (current ).strip ())
83+ current = []
84+ inside = False
85+ else :
86+ inside = True
87+ elif inside :
88+ current .append (line )
89+ if current :
90+ blocks .append ("\n " .join (current ).strip ())
91+
92+ if blocks :
93+ return max (blocks , key = len )
94+ return text .strip ()
95+
96+
97+ async def generate_tests (ctx : Context , paths : dict ):
98+ module_name , commands , test_file = await check_path_status (ctx , paths )
99+
100+ if test_file is None and not commands :
101+ return str (module_name )
102+
103+ if not commands :
104+ return f"No commands found to generate tests for module '{ module_name } '."
105+
106+ reference_snippet = "\n " .join ([
107+ "azure-cli/src/azure-cli/azure/cli/command_modules/resource/tests/latest/test_resource.py" ,
108+ "azure-cli/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py"
109+ ])
110+
111+ sampling_prompt = build_testgen_prompt (
112+ module_name , commands , reference_snippet )
113+ ctx .info ("Constructed test generation prompt as follows:\n " + sampling_prompt )
114+ max_retries = int (os .getenv ("TESTGEN_RETRIES" , "5" ))
115+ base_delay = float (os .getenv ("TESTGEN_RETRY_BASE_DELAY" , "2" ))
116+
117+ attempt = 0
118+ content = ""
119+ last_err = None
120+ while attempt <= max_retries :
121+ try :
122+ if attempt > 0 :
123+ await ctx .info (f"Retrying test generation (attempt { attempt } /{ max_retries } )..." )
124+ sampled = await ctx .sample (sampling_prompt )
125+ raw_content = (getattr (sampled , "text" , "" ) or "" ).strip ()
126+ content = strip_code_fences (raw_content )
127+ if content :
128+ break
129+ last_err = RuntimeError ("Empty content returned from provider" )
130+ raise last_err
131+ except Exception as ex :
132+ last_err = ex
133+ message = str (ex ).lower ()
134+ retriable = any (k in message for k in [
135+ "rate limit" , "overloaded" , "timeout" , "temporarily unavailable" , "429"
136+ ])
137+ if attempt >= max_retries or not retriable :
138+ break
139+ delay = base_delay * (2 ** attempt )
140+ delay = min (delay , 30 )
141+ jitter = random .uniform (0.7 , 1.3 )
142+ wait_time = delay * jitter
143+ await ctx .info (f"Transient error encountered: { ex } . Waiting { wait_time :.1f} s before retry." )
144+ await asyncio .sleep (wait_time )
145+ attempt += 1
146+ continue
147+
148+ if not content :
149+ if last_err :
150+ return (
151+ f"Test generation failed after { max_retries } retries for module '{ module_name } ': "
152+ f"{ last_err } "
153+ )
154+ return f"Test generation failed: no content generated for module '{ module_name } '."
155+
156+ with open (test_file , "w" , encoding = "utf-8" ) as f :
157+ f .write (content )
158+
159+ await ctx .info (f"Generated test file: { test_file } " )
160+ return f"Test generation completed for module '{ module_name } '."
0 commit comments