14
14
15
15
# On GPU-1, run first 1/4 of inputs for all kernels and save results to CSV in the current directory
16
16
$ CUDA_VISIBLE_DEVICES=1 python benchmarks/run.py --input-shard 1/4 --metrics accuracy,tflops,gbps,speedup --csv --output-dir ./
17
+
18
+ # Run 5 equally-spaced inputs instead of the first 5
19
+ $ python benchmarks/run.py --kernel vector_add --num-inputs 5 --input-sample-mode equal-spaced
17
20
"""
18
21
19
22
from __future__ import annotations
@@ -267,6 +270,7 @@ def run_kernel(
267
270
kernel_name : str ,
268
271
tritonbench_args : list [str ],
269
272
input_shard_info : tuple [int , int ] | None = None ,
273
+ input_sample_mode : str = "first-n" ,
270
274
) -> None :
271
275
"""Run a kernel benchmark, handling both single and multiple variants."""
272
276
# Check if kernel is in the mapping table
@@ -313,6 +317,7 @@ def run_kernel(
313
317
tritonbench_args ,
314
318
input_shard_info ,
315
319
operator_args ,
320
+ input_sample_mode ,
316
321
)
317
322
318
323
@@ -323,6 +328,7 @@ def run_kernel_variants(
323
328
tritonbench_args : list [str ],
324
329
input_shard_info : tuple [int , int ] | None = None ,
325
330
operator_args : dict [str , Any ] | None = None ,
331
+ input_sample_mode : str = "first-n" ,
326
332
) -> None :
327
333
"""Run kernel variants in the same benchmark run."""
328
334
@@ -461,18 +467,61 @@ def _inner() -> Callable[..., Any] | object:
461
467
462
468
from tritonbench .run import _run
463
469
464
- # Handle input sharding if requested
470
+ # Get the actual number of inputs for this operator
471
+ total_inputs = Operator (
472
+ tb_args = tb_args , extra_args = unknown_args
473
+ )._available_num_inputs
474
+
475
+ # First, handle input sampling based on --num-inputs and --input-sample-mode
476
+ selected_indices = None
477
+
478
+ if "--num-inputs" in tritonbench_args :
479
+ # Make a copy to avoid modifying the original list
480
+ tritonbench_args = tritonbench_args .copy ()
481
+ # Extract num-inputs value
482
+ num_inputs_idx = tritonbench_args .index ("--num-inputs" )
483
+ if num_inputs_idx + 1 < len (tritonbench_args ):
484
+ num_inputs = int (tritonbench_args [num_inputs_idx + 1 ])
485
+
486
+ if input_sample_mode == "equal-spaced" :
487
+ # Calculate equal-spaced indices
488
+ if num_inputs >= total_inputs :
489
+ # If requesting at least as many inputs as are available, just use all of them
490
+ selected_indices = list (range (total_inputs ))
491
+ else :
492
+ # Calculate step size for equal spacing
493
+ step = (total_inputs - 1 ) / (num_inputs - 1 ) if num_inputs > 1 else 0
494
+ selected_indices = [int (round (i * step )) for i in range (num_inputs )]
495
+
496
+ print (
497
+ f"Step 1 - Equal-spaced sampling: { num_inputs } inputs from { total_inputs } total" ,
498
+ file = sys .stderr ,
499
+ )
500
+ print (f" Selected indices: { selected_indices } " , file = sys .stderr )
501
+ else :
502
+ # first-n mode: select first N inputs
503
+ selected_indices = list (range (min (num_inputs , total_inputs )))
504
+ print (
505
+ f"Step 1 - First-n sampling: { num_inputs } inputs from { total_inputs } total" ,
506
+ file = sys .stderr ,
507
+ )
508
+ print (f" Selected indices: { selected_indices } " , file = sys .stderr )
509
+
510
+ # Remove --num-inputs from args since we'll handle it differently
511
+ tritonbench_args .pop (num_inputs_idx ) # Remove --num-inputs
512
+ tritonbench_args .pop (num_inputs_idx ) # Remove the value
513
+ else :
514
+ # No sampling requested, use all inputs
515
+ selected_indices = list (range (total_inputs ))
516
+
517
+ # Then, handle sharding if requested
465
518
if input_shard_info :
466
519
shard_idx , total_shards = input_shard_info
467
-
468
- # Get the actual number of inputs for this operator
469
- total_inputs = Operator (
470
- tb_args = tb_args , extra_args = unknown_args
471
- )._available_num_inputs
472
-
473
- # Calculate shard boundaries
474
- inputs_per_shard = total_inputs // total_shards
475
- extra_inputs = total_inputs % total_shards
520
+
521
+ # Calculate shard boundaries on the selected indices
522
+ num_selected = len (selected_indices )
523
+ inputs_per_shard = num_selected // total_shards
524
+ extra_inputs = num_selected % total_shards
476
525
477
526
if shard_idx <= extra_inputs :
478
527
start_idx = (shard_idx - 1 ) * (inputs_per_shard + 1 )
@@ -484,15 +533,21 @@ def _inner() -> Callable[..., Any] | object:
484
533
)
485
534
shard_size = inputs_per_shard
486
535
536
+ # Get the actual indices for this shard
537
+ shard_indices = selected_indices [start_idx :start_idx + shard_size ]
538
+
487
539
print (
488
- f"Running input shard { shard_idx } /{ total_shards } : inputs { start_idx } to { start_idx + shard_size - 1 } (of { total_inputs } total) " ,
540
+ f"Step 2 - Sharding: shard { shard_idx } /{ total_shards } gets { len ( shard_indices ) } inputs " ,
489
541
file = sys .stderr ,
490
542
)
543
+ print (f" Shard indices: { shard_indices } " , file = sys .stderr )
544
+
545
+ # Update selected_indices to only include this shard
546
+ selected_indices = shard_indices
491
547
492
- # Add input-id and num-inputs to the tritonbench args before re-parsing
493
- tritonbench_args .extend (
494
- ["--input-id" , str (start_idx ), "--num-inputs" , str (shard_size )]
495
- )
548
+ # Pass the final selected indices via --input-id, but only when they form a non-empty proper subset of all inputs
549
+ if selected_indices is not None and len (selected_indices ) > 0 and len (selected_indices ) < total_inputs :
550
+ tritonbench_args .extend (["--input-id" , "," .join (map (str , selected_indices ))])
496
551
497
552
# Re-parse args so the input selection applied above takes effect
498
553
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
@@ -523,6 +578,13 @@ def main() -> None:
523
578
type = str ,
524
579
help = "Run only a subset of inputs for each kernel. Format: M/N where M is the shard number (1-indexed) and N is the total number of shards. For example, --input-shard 1/3 runs the first third of inputs for each kernel." ,
525
580
)
581
+ parser .add_argument (
582
+ "--input-sample-mode" ,
583
+ type = str ,
584
+ choices = ["first-n" , "equal-spaced" ],
585
+ default = "first-n" ,
586
+ help = "How to sample inputs when using --num-inputs. 'first-n' (default) takes the first X inputs. 'equal-spaced' takes X inputs equally spaced throughout the input list." ,
587
+ )
526
588
527
589
# Parse known args to get the kernel name, pass rest to tritonbench
528
590
args , tritonbench_args = parser .parse_known_args ()
@@ -568,7 +630,7 @@ def main() -> None:
568
630
569
631
# Run specified kernels
570
632
if len (kernel_names ) == 1 :
571
- run_kernel (kernel_names [0 ], tritonbench_args , input_shard_info )
633
+ run_kernel (kernel_names [0 ], tritonbench_args , input_shard_info , args . input_sample_mode )
572
634
else :
573
635
print (
574
636
f"Running { len (kernel_names )} kernels: { ', ' .join (kernel_names )} ...\n " ,
@@ -578,15 +640,15 @@ def main() -> None:
578
640
print (f"\n { '=' * 60 } " , file = sys .stderr )
579
641
print (f"Kernel: { kernel_name } " , file = sys .stderr )
580
642
print (f"{ '=' * 60 } \n " , file = sys .stderr )
581
- run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info )
643
+ run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info , args . input_sample_mode )
582
644
else :
583
645
# Run all kernels
584
646
print (f"Running all { len (KERNEL_MAPPINGS )} kernels...\n " , file = sys .stderr )
585
647
for kernel_name in KERNEL_MAPPINGS :
586
648
print (f"\n { '=' * 60 } " , file = sys .stderr )
587
649
print (f"Kernel: { kernel_name } " , file = sys .stderr )
588
650
print (f"{ '=' * 60 } \n " , file = sys .stderr )
589
- run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info )
651
+ run_kernel (kernel_name , tritonbench_args .copy (), input_shard_info , args . input_sample_mode )
590
652
591
653
592
654
if __name__ == "__main__" :
0 commit comments