Skip to content

Commit bad2950

Browse files
authored
[LOG] More details log for key when TRITON_PRINT_AUTOTUNING=1 (#7017)
<!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) # More details log for key when TRITON_PRINT_AUTOTUNING=1 Fix #6636 When running auto tuner with `TRITON_PRINT_AUTOTUNING=1`, we could just see the best configuration but not **key**, some user would like to see the **key** if there is multiple combination of **key**. This PR tries to log more details about key.
1 parent e322605 commit bad2950

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/triton/runtime/autotuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def benchmark():
243243
config = self.configs[0]
244244
self.best_config = config
245245
if knobs.autotuning.print and not used_cached_result:
246-
print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
247-
f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
246+
print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
247+
f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
248248
if config.pre_hook is not None:
249249
full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
250250
config.pre_hook(full_nargs)

0 commit comments

Comments
 (0)