Commit cfb7ebf

[New Feature] CUDA agent compiler backend
1 parent 9a55a74 commit cfb7ebf

11 files changed: +1193 -0 lines changed
graph_net/config/config_agent_backend.yaml

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
llm_name: "gpt-5"
temperature: 1.0
top_p: 1.0
max_tokens: 16384
parallel_query_nums: 1
iterative_query_nums: 2

# responses and logs will be saved in <top_save_dir>/<llm_name>
top_save_dir: "./tmp/llm_cache"

# network settings
timeout_seconds: 600.0
max_retries: 3
backoff_initial_seconds: 1.0
backoff_max_seconds: 16.0

# local inference settings
is_local_inference: False
local_inference_vendor: "vllm"  # only used when printing logs
local_inference_port: 30000
local_inference_address: "localhost"
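For orientation, the network settings above map onto a standard retry-with-exponential-backoff pattern. The sketch below is illustrative only: the real request logic lives in the agent_utils query helpers (not part of this commit), and send_fn, the broad exception handling, and the jitter are assumptions.

import random
import time


def post_with_retries(send_fn, timeout_seconds=600.0, max_retries=3,
                      backoff_initial_seconds=1.0, backoff_max_seconds=16.0):
    # Retry send_fn with exponential backoff, doubling the delay up to the cap.
    delay = backoff_initial_seconds
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            return send_fn(timeout=timeout_seconds)
        except Exception as error:  # real code would catch narrower network errors
            last_error = error
            if attempt == max_retries:
                break
            time.sleep(delay + random.uniform(0.0, 0.1))  # small jitter
            delay = min(delay * 2.0, backoff_max_seconds)
    raise RuntimeError(f"request failed after {max_retries + 1} attempts: {last_error}")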
graph_net/config/fetch_agent_config.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
import os
import yaml


def get_llm_config() -> dict:
    """Load the agent backend's LLM settings from config_agent_backend.yaml.

    API credentials come from the environment (GRAPHNET_AGENT_API_KEY,
    GRAPHNET_AGENT_BASE_URL), never from the YAML file itself.
    """
    api_key = os.getenv("GRAPHNET_AGENT_API_KEY", None)
    base_url = os.getenv("GRAPHNET_AGENT_BASE_URL", None)
    config_dir = os.path.dirname(os.path.abspath(__file__))
    llm_config_path = os.path.join(config_dir, "config_agent_backend.yaml")
    with open(llm_config_path, "r") as file:
        llm_config = yaml.safe_load(file)
    llm_config["api_key"] = api_key
    llm_config["api_url"] = base_url
    return llm_config
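A minimal usage sketch for this loader follows. The key and base URL are placeholders, and the example assumes LLMQueryConfig accepts the api_key/api_url fields injected here, as the **llm_config splat in agent_compile.py below suggests.

import os

from graph_net.config.fetch_agent_config import get_llm_config
from graph_net.torch.backend.agent_utils.query_llm_utils import LLMQueryConfig

# Placeholder credentials for an OpenAI-compatible endpoint.
os.environ["GRAPHNET_AGENT_API_KEY"] = "sk-placeholder"
os.environ["GRAPHNET_AGENT_BASE_URL"] = "https://api.example.com/v1"

llm_config = get_llm_config()                 # YAML settings plus env credentials
query_config = LLMQueryConfig(**llm_config)   # same construction agent_compile.py uses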
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import torch
from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
from graph_net.torch.backend.agent_ncu import agent_compile


class AgentCompiledModule(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.count_compile = 0

    def forward(self, *args, **kwargs):
        # Re-optimize against the observed inputs, then run the optimized module.
        self.module = self.compile(*args, **kwargs)
        return self.module(*args, **kwargs)

    def compile(self, *args, **kwargs):
        # Positional and keyword argument values together form the dummy input.
        dummy_input = tuple([*args, *kwargs.values()])
        optimized_module = agent_compile.optimize(
            self.module,
            model_inputs=dummy_input,
            task_name=f"default_task_{self.count_compile}",
            language="cuda",
        )
        self.count_compile += 1
        return optimized_module


class AgentBackend(GraphCompilerBackend):
    def __call__(self, model):
        return AgentCompiledModule(model)

    def synchronize(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
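A hypothetical standalone usage sketch of the backend follows. In practice the graph_net benchmark harness instantiates the backend and drives it; running this end to end needs a configured LLM endpoint plus the CUDA toolchain, so the toy model and input are for illustration only (it also assumes GraphCompilerBackend imposes no additional required methods).

import torch

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
backend = AgentBackend()
compiled = backend(model)   # wraps the model in AgentCompiledModule
x = torch.randn(8, 64)
out = compiled(x)           # triggers agent_compile.optimize, then runs the result
backend.synchronize()       # torch.cuda.synchronize() when CUDA is available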

graph_net/torch/backend/agent_ncu/__init__.py

Whitespace-only changes.
graph_net/torch/backend/agent_ncu/agent_compile.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
import os
import torch
import warnings
from graph_net.config.fetch_agent_config import get_llm_config
from graph_net.torch.backend.agent_utils.query_llm_utils import (
    LLMQueryConfig,
    query_llm_service,
    add_token_usage,
)
from .prompt_cuda import (
    generate_default_cuda_prompt,
    judge_optimize_prompt,
    judge_correct_prompt,
    coder_optimize_prompt,
    coder_correct_prompt,
)
from .kernel_text_util import (
    extract_cuda_code,
    remove_pybind_module,
    compile_kernel,
    exec_eval_cuda,
    exec_eval_cuda_with_ncu,
    format_with_kernelbench_style,
)
from .prompt_cuda import PROMPT_SYSTEM


def generate(prompt, system_prompt, llm_query_config: LLMQueryConfig):
    try:
        query_result = query_llm_service(
            prompt=prompt, system_prompt=system_prompt, query_config=llm_query_config
        )
        return query_result
    except Exception as e:
        raise RuntimeError(f"LLM query failed with exception: {e}")


def optimize(
    module, model_inputs=None, language: str = "cuda", task_name: str = "default_task"
):
    """Optimize the given PyTorch module using custom DSL operators."""

    llm_config = get_llm_config()
    llm_query_config = LLMQueryConfig(**llm_config)
    traced_module = torch.fx.symbolic_trace(module)

    if "cuda" == language:
        return cuda_optimize(traced_module, model_inputs, task_name, llm_query_config)
    elif "triton" == language:
        return torch.compile(module)  # TODO add custom triton optimize
    else:
        raise NotImplementedError(f"Unsupported language: {language}")

    # return the best of optimized models


def agent_fix_cuda_error(
    error_info, torch_model_code, cuda_code, work_dir, llm_query_config
):
    """Fix CUDA code based on error information from compilation or execution."""

    prompt = judge_correct_prompt(error_info, torch_model_code, cuda_code)
    modify_text, tokens_judge_correct = generate(
        prompt, PROMPT_SYSTEM, llm_query_config
    )
    prompt = coder_correct_prompt(error_info, cuda_code, modify_text)
    with open(os.path.join(work_dir, "fix_cuda_error_prompt.txt"), "a") as f:
        f.write(prompt)
    cuda_code, tokens_cuda_code_fix = generate(prompt, PROMPT_SYSTEM, llm_query_config)
    cuda_code = extract_cuda_code(cuda_code)
    cuda_code = remove_pybind_module(cuda_code)
    return cuda_code, add_token_usage(tokens_judge_correct, tokens_cuda_code_fix)


def cuda_optimize(
    gm,
    model_inputs,
    task_name: str = "default_task",
    llm_query_config: LLMQueryConfig = None,
):
    best_model = gm
    max_iters = llm_query_config.iterative_query_nums
    store_dir = os.path.join(llm_query_config.top_save_dir, task_name)

    torch_model_code = format_with_kernelbench_style(gm, model_inputs)

    # iterative optimization
    cur_iter_token_usage = None
    for iter in range(max_iters):
        print(f"=== Optimize {task_name}, Iteration {iter} ===", flush=True)

        context_dir_path = os.path.join(store_dir, f"iter_{iter}")
        already_done: bool = os.path.exists(
            os.path.join(context_dir_path, "model_new.py")
        )

        # Skip already done iterations
        if already_done:
            print(f"Iteration {iter} already done, skipping...", flush=True)
            with open(os.path.join(context_dir_path, "model_new.py"), "r") as fin:
                cuda_code = fin.read()
            continue
        os.makedirs(context_dir_path, exist_ok=True)

        is_success_compilable = False
        is_success_functional = False

        # Generate initial kernel code
        if iter == 0:
            prompt = generate_default_cuda_prompt(torch_model_code)
            text_response, cur_iter_token_usage = generate(
                prompt, PROMPT_SYSTEM, llm_query_config
            )
            raw_cuda_code = extract_cuda_code(text_response)
            cuda_code = remove_pybind_module(raw_cuda_code)

        # compile
        try:
            is_success_compilable, compile_info = compile_kernel(
                cuda_code=cuda_code, work_dir=context_dir_path
            )

            with open(os.path.join(context_dir_path, "log.log"), "a") as f:
                f.write(
                    f"[Token Usage] Iteration {iter} cost: {cur_iter_token_usage}\n"
                )
            cur_iter_token_usage = None

        except Exception as e:  # wrap compile_kernel to catch all exceptions
            print(f"Compilation failed with exception: {e}")
            continue

        # [Eval Result] compile failed
        if not is_success_compilable:
            cuda_code, fix_error_token = agent_fix_cuda_error(
                compile_info["msg"][:4096],  # save tokens
                torch_model_code,
                cuda_code,
                work_dir=context_dir_path,
                llm_query_config=llm_query_config,
            )
            cur_iter_token_usage = add_token_usage(
                cur_iter_token_usage, fix_error_token
            )
            continue

        # [Result-Compile] compile success
        try:
            is_success_functional, eval_msg = exec_eval_cuda(
                compile_info["exec_filename"],  # .so filename
                compile_info["exec_content"],  # .so binary content
                torch_model_code,
                work_dir=context_dir_path,
            )
        except Exception as e:
            print(f"Execution failed with exception: {e}", flush=True)
            continue

        # [Result-Execute] functional failed
        if not is_success_functional:
            cuda_code, fix_error_token = agent_fix_cuda_error(
                eval_msg[:4096],
                torch_model_code,
                cuda_code,
                work_dir=context_dir_path,
                llm_query_config=llm_query_config,
            )
            cur_iter_token_usage = add_token_usage(
                cur_iter_token_usage, fix_error_token
            )
            continue

        # [Result-Execute] functional success: optimization with NCU analysis
        else:
            is_ncu_success, ncu_metric_info = exec_eval_cuda_with_ncu(
                compile_info["exec_filename"],
                compile_info["exec_content"],
                work_dir=context_dir_path,
            )
            if not is_ncu_success:
                warnings.warn("NCU analysis failed.", RuntimeWarning)
                continue
            optimize_prompt = judge_optimize_prompt(
                torch_model_code, cuda_code, ncu_metric_info
            )
            optimize_strategy, strategy_token = generate(
                optimize_prompt, PROMPT_SYSTEM, llm_query_config
            )
            optimize_prompt = coder_optimize_prompt(cuda_code, optimize_strategy)
            cuda_code, cuda_gen_token = generate(
                optimize_prompt, PROMPT_SYSTEM, llm_query_config
            )
            cuda_code = extract_cuda_code(cuda_code)
            cuda_code = remove_pybind_module(cuda_code)
            cur_iter_token_usage = add_token_usage(cur_iter_token_usage, strategy_token)
            cur_iter_token_usage = add_token_usage(cur_iter_token_usage, cuda_gen_token)

    return best_model
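To make the entry point concrete, here is a minimal call sketch. The model, inputs, and task_name are invented for illustration, and the CUDA path only does real work with a reachable LLM endpoint plus nvcc and Nsight Compute (ncu) available. Per cuda_optimize above, each iteration's prompts, logs, and generated kernel land under <top_save_dir>/<task_name>/iter_<k>.

import torch

from graph_net.torch.backend.agent_ncu import agent_compile

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU())
example_inputs = (torch.randn(4, 128),)

optimized = agent_compile.optimize(
    model,
    model_inputs=example_inputs,
    language="cuda",       # "triton" currently falls back to torch.compile
    task_name="demo_mlp",  # hypothetical task name; controls the cache directory
)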
