Skip to content

Commit 3d23844

Browse files
authored
Fix native code generation in test_aot.py (#4127)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b189340 commit 3d23844

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

python/test/unit/tools/test_aot.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,7 @@ def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps,
388388
str(num_warps),
389389
"-g",
390390
grid,
391-
"-gspv",
392-
str(not generate_native_code),
391+
*(["-gnc"] if generate_native_code else []),
393392
kernel_path,
394393
],
395394
check=True,

python/triton/tools/compile.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
5454
parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
5555
parser.add_argument("--grf-mode", "-gm", type=str, default="large", help="Detemine spv build flags")
56-
parser.add_argument("--generate-spv", "-gspv", type=bool, default=True, help="Cache SPV or native binary for XPU")
56+
parser.add_argument("--generate-native-code", "-gnc", action="store_true",
57+
help="Generate native binary instead of SPV for XPU")
5758
args = parser.parse_args()
5859

5960
out_name = args.out_name if args.out_name else args.kernel_name
@@ -115,7 +116,7 @@ def constexpr(s):
115116
if is_xpu():
116117
opts = {
117118
"num_warps": args.num_warps, "num_stages": args.num_stages, "threads_per_warp": args.threads_per_warp,
118-
"grf_mode": args.grf_mode, "generate_native_code": not args.generate_spv
119+
"grf_mode": args.grf_mode, "generate_native_code": args.generate_native_code
119120
}
120121
ccinfo = triton.compile(src, options=opts)
121122
if is_cuda():
@@ -195,7 +196,7 @@ def constexpr(s):
195196
"gridX": grid[0],
196197
"gridY": grid[1],
197198
"gridZ": grid[2],
198-
"is_spv": "true" if args.generate_spv else "false",
199+
"is_spv": "false" if args.generate_native_code else "true",
199200
"_placeholder": "",
200201
}
201202
for ext in ['h', 'cpp']:

0 commit comments

Comments
 (0)