|
53 | 53 | parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) |
54 | 54 | parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) |
55 | 55 | parser.add_argument("--grf-mode", "-gm", type=str, default="large", help="Detemine spv build flags") |
56 | | - parser.add_argument("--generate-spv", "-gspv", type=bool, default=True, help="Cache SPV or native binary for XPU") |
| 56 | + parser.add_argument("--generate-native-code", "-gnc", action="store_true", |
| 57 | + help="Generate native binary instead of SPV for XPU") |
57 | 58 | args = parser.parse_args() |
58 | 59 |
|
59 | 60 | out_name = args.out_name if args.out_name else args.kernel_name |
@@ -115,7 +116,7 @@ def constexpr(s): |
115 | 116 | if is_xpu(): |
116 | 117 | opts = { |
117 | 118 | "num_warps": args.num_warps, "num_stages": args.num_stages, "threads_per_warp": args.threads_per_warp, |
118 | | - "grf_mode": args.grf_mode, "generate_native_code": not args.generate_spv |
| 119 | + "grf_mode": args.grf_mode, "generate_native_code": args.generate_native_code |
119 | 120 | } |
120 | 121 | ccinfo = triton.compile(src, options=opts) |
121 | 122 | if is_cuda(): |
@@ -195,7 +196,7 @@ def constexpr(s): |
195 | 196 | "gridX": grid[0], |
196 | 197 | "gridY": grid[1], |
197 | 198 | "gridZ": grid[2], |
198 | | - "is_spv": "true" if args.generate_spv else "false", |
| 199 | + "is_spv": "false" if args.generate_native_code else "true", |
199 | 200 | "_placeholder": "", |
200 | 201 | } |
201 | 202 | for ext in ['h', 'cpp']: |
|
0 commit comments