File tree Expand file tree Collapse file tree 1 file changed +11
-14
lines changed Expand file tree Collapse file tree 1 file changed +11
-14
lines changed Original file line number Diff line number Diff line change @@ -2349,21 +2349,18 @@ def total_numel() -> int:
23492349 check_max_block (cfg )
23502350 check_config (cfg , xnumel = size_hints ["x" ])
23512351
2352+ config = InductorConfig (
2353+ cfg ,
2354+ num_warps = num_warps ,
2355+ num_stages = num_stages ,
2356+ dynamic_scale_rblock = dynamic_scale_rblock ,
2357+ )
2358+
23522359 if torch .version .hip :
2353- return InductorConfig (
2354- cfg ,
2355- num_warps = num_warps ,
2356- num_stages = num_stages ,
2357- waves_per_eu = waves_per_eu ,
2358- dynamic_scale_rblock = dynamic_scale_rblock ,
2359- )
2360- else :
2361- return InductorConfig (
2362- cfg ,
2363- num_warps = num_warps ,
2364- num_stages = num_stages ,
2365- dynamic_scale_rblock = dynamic_scale_rblock ,
2366- )
2360+ if waves_per_eu is not None :
2361+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2362+
2363+ return config
23672364
23682365
23692366def _get_config (numels : dict [str , int ]) -> dict [str , int ]:
You can’t perform that action at this time.
0 commit comments