Skip to content

Commit d09906a

Browse files
authored
Fix issue in merge conflict
1 parent 7e4a926 commit d09906a

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,21 +2349,18 @@ def total_numel() -> int:
23492349
check_max_block(cfg)
23502350
check_config(cfg, xnumel=size_hints["x"])
23512351

2352+
config = InductorConfig(
2353+
cfg,
2354+
num_warps=num_warps,
2355+
num_stages=num_stages,
2356+
dynamic_scale_rblock=dynamic_scale_rblock,
2357+
)
2358+
23522359
if torch.version.hip:
2353-
return InductorConfig(
2354-
cfg,
2355-
num_warps=num_warps,
2356-
num_stages=num_stages,
2357-
waves_per_eu=waves_per_eu,
2358-
dynamic_scale_rblock=dynamic_scale_rblock,
2359-
)
2360-
else:
2361-
return InductorConfig(
2362-
cfg,
2363-
num_warps=num_warps,
2364-
num_stages=num_stages,
2365-
dynamic_scale_rblock=dynamic_scale_rblock,
2366-
)
2360+
if waves_per_eu is not None:
2361+
config.kwargs["waves_per_eu"] = waves_per_eu
2362+
2363+
return config
23672364

23682365

23692366
def _get_config(numels: dict[str, int]) -> dict[str, int]:

0 commit comments

Comments
 (0)