@@ -205,7 +205,11 @@ def check_and_setup_tritonbench() -> None:
205
205
sys .exit (1 )
206
206
207
207
208
- def run_kernel (kernel_name : str , tritonbench_args : list [str ]) -> None :
208
+ def run_kernel (
209
+ kernel_name : str ,
210
+ tritonbench_args : list [str ],
211
+ input_shard_info : tuple [int , int ] | None = None ,
212
+ ) -> None :
209
213
"""Run a single kernel benchmark."""
210
214
# Check if kernel is in the mapping table
211
215
if kernel_name not in KERNEL_MAPPINGS :
@@ -343,6 +347,36 @@ def _inner() -> Callable[..., Any] | object:
343
347
# Create and run the operator with unknown args
344
348
op = Operator (tb_args = tb_args , extra_args = unknown_args )
345
349
350
+ # Handle input sharding if requested
351
+ if input_shard_info :
352
+ shard_idx , total_shards = input_shard_info
353
+
354
+ # Get the actual number of inputs for this operator
355
+ total_inputs = op ._available_num_inputs
356
+
357
+ # Calculate shard boundaries
358
+ inputs_per_shard = total_inputs // total_shards
359
+ extra_inputs = total_inputs % total_shards
360
+
361
+ if shard_idx <= extra_inputs :
362
+ start_idx = (shard_idx - 1 ) * (inputs_per_shard + 1 )
363
+ shard_size = inputs_per_shard + 1
364
+ else :
365
+ start_idx = (
366
+ extra_inputs * (inputs_per_shard + 1 )
367
+ + (shard_idx - 1 - extra_inputs ) * inputs_per_shard
368
+ )
369
+ shard_size = inputs_per_shard
370
+
371
+ # Override the operator's input range
372
+ op ._input_id = start_idx
373
+ op ._num_inputs = shard_size
374
+
375
+ print (
376
+ f"Running input shard { shard_idx } /{ total_shards } : inputs { start_idx } to { start_idx + shard_size - 1 } (of { total_inputs } total)" ,
377
+ file = sys .stderr ,
378
+ )
379
+
346
380
# Run with proper parameters
347
381
warmup = int (getattr (tb_args , "warmup" , 25 ))
348
382
rep = int (getattr (tb_args , "iter" , 100 ))
@@ -369,13 +403,37 @@ def main() -> None:
369
403
type = str ,
370
404
help = "Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels." ,
371
405
)
406
+ parser .add_argument (
407
+ "--input-shard" ,
408
+ type = str ,
409
+ help = "Run only a subset of inputs for each kernel. Format: M/N where M is the shard number (1-indexed) and N is the total number of shards. For example, --input-shard 1/3 runs the first third of inputs for each kernel." ,
410
+ )
372
411
373
412
# Parse known args to get the kernel name, pass rest to tritonbench
374
413
args , tritonbench_args = parser .parse_known_args ()
375
414
376
415
# Check and setup tritonbench if needed
377
416
check_and_setup_tritonbench ()
378
417
418
+ # Store input-shard info for later processing
419
+ input_shard_info = None
420
+ if args .input_shard :
421
+ try :
422
+ shard_idx , total_shards = map (int , args .input_shard .split ("/" ))
423
+ if shard_idx < 1 or shard_idx > total_shards :
424
+ print (
425
+ f"Error: Shard number { shard_idx } must be between 1 and { total_shards } " ,
426
+ file = sys .stderr ,
427
+ )
428
+ sys .exit (1 )
429
+ input_shard_info = (shard_idx , total_shards )
430
+ except ValueError :
431
+ print (
432
+ f"Error: Invalid input-shard format '{ args .input_shard } '. Expected format: M/N (e.g., 1/3)" ,
433
+ file = sys .stderr ,
434
+ )
435
+ sys .exit (1 )
436
+
379
437
if args .kernel :
380
438
# Parse comma-separated kernel names
381
439
kernel_names = [k .strip () for k in args .kernel .split ("," )]
@@ -395,7 +453,7 @@ def main() -> None:
395
453
396
454
# Run specified kernels
397
455
if len (kernel_names ) == 1 :
398
- run_kernel (kernel_names [0 ], tritonbench_args )
456
+ run_kernel (kernel_names [0 ], tritonbench_args , input_shard_info )
399
457
else :
400
458
print (
401
459
f"Running { len (kernel_names )} kernels: { ', ' .join (kernel_names )} ...\n " ,
@@ -405,15 +463,15 @@ def main() -> None:
405
463
print (f"\n { '=' * 60 } " , file = sys .stderr )
406
464
print (f"Kernel: { kernel_name } " , file = sys .stderr )
407
465
print (f"{ '=' * 60 } \n " , file = sys .stderr )
408
- run_kernel (kernel_name , tritonbench_args .copy ())
466
+ run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info )
409
467
else :
410
468
# Run all kernels
411
469
print (f"Running all { len (KERNEL_MAPPINGS )} kernels...\n " , file = sys .stderr )
412
470
for kernel_name in KERNEL_MAPPINGS :
413
471
print (f"\n { '=' * 60 } " , file = sys .stderr )
414
472
print (f"Kernel: { kernel_name } " , file = sys .stderr )
415
473
print (f"{ '=' * 60 } \n " , file = sys .stderr )
416
- run_kernel (kernel_name , tritonbench_args .copy ())
474
+ run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info )
417
475
418
476
419
477
if __name__ == "__main__" :
0 commit comments