Skip to content

Commit 4dd5db1

Browse files
committed
feat: add DSL config
1 parent 47bab0a commit 4dd5db1

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

graph_net/config/config_agent_backend.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ max_tokens: 16384
55
parallel_query_nums: 1
66
iterative_query_nums: 2
77

8+
# Supported DSLs: "CUDA" and "Triton"; "TileLang" and others will be supported in the future.
9+
DSL: "CUDA"
10+
811
# responses and logs will be saved in <top_save_dir>/<llm_name>
912
top_save_dir: "./tmp/llm_cache"
1013

graph_net/torch/backend/agent_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def compile(self, *args, **kwargs):
1919
self.module,
2020
model_inputs=dummy_input,
2121
task_name=f"default_task_{self.count_compile}",
22-
language="cuda",
2322
)
2423
self.count_compile += 1
2524
return optimized_module

graph_net/torch/backend/agent_ncu/agent_compile.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,19 @@ def generate(prompt, system_prompt, llm_query_config: LLMQueryConfig):
3535
raise RuntimeError(f"LLM query failed with exception: {e}")
3636

3737

38-
def optimize(
39-
module, model_inputs=None, language: str = "cuda", task_name: str = "default_task"
40-
):
38+
def optimize(module, model_inputs=None, task_name: str = "default_task"):
4139
"""Optimize the given PyTorch module using custom DSL operators."""
4240

4341
llm_config = get_llm_config()
4442
llm_query_config = LLMQueryConfig(**llm_config)
4543
traced_module = torch.fx.symbolic_trace(module)
4644

47-
if "cuda" == language:
45+
if "cuda" == llm_query_config.dsl.lower():
4846
return cuda_optimize(traced_module, model_inputs, task_name, llm_query_config)
49-
elif "triton" == language:
47+
elif "triton" == llm_query_config.dsl.lower():
5048
return torch.compile(module) # TODO add custom triton optimize
5149
else:
52-
raise NotImplementedError(f"Unsupported language: {language}")
50+
raise NotImplementedError(f"Unsupported DSL: {llm_query_config.dsl}")
5351

5452
# return the best of optimized models
5553

graph_net/torch/backend/agent_utils/query_llm_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class LLMQueryConfig:
1717
parallel_query_nums: int = 1
1818
iterative_query_nums: int = 10
1919

20+
dsl: str = "CUDA"
21+
2022
# cache settings
2123
# responses will be saved in <tmp_llm_cache>/<llm_name>/
2224
top_save_dir: str = "./tmp_llm_cache"

0 commit comments

Comments
 (0)