
Commit e01b2d0

add equal-spaced mode
1 parent 27fea44 commit e01b2d0

File tree

1 file changed: +80 -18 lines

benchmarks/run.py

Lines changed: 80 additions & 18 deletions
@@ -14,6 +14,9 @@
 
 # On GPU-1, run first 1/4 of inputs for all kernels and save results to CSV in the current directory
 $ CUDA_VISIBLE_DEVICES=1 python benchmarks/run.py --input-shard 1/4 --metrics accuracy,tflops,gbps,speedup --csv --output-dir ./
+
+# Run 5 equally-spaced inputs instead of the first 5
+$ python benchmarks/run.py --kernel vector_add --num-inputs 5 --input-sample-mode equal-spaced
 """
 
 from __future__ import annotations
@@ -267,6 +270,7 @@ def run_kernel(
     kernel_name: str,
     tritonbench_args: list[str],
     input_shard_info: tuple[int, int] | None = None,
+    input_sample_mode: str = "first-n",
 ) -> None:
     """Run a kernel benchmark, handling both single and multiple variants."""
     # Check if kernel is in the mapping table
@@ -313,6 +317,7 @@ def run_kernel(
         tritonbench_args,
         input_shard_info,
         operator_args,
+        input_sample_mode,
     )
 
 
@@ -323,6 +328,7 @@ def run_kernel_variants(
     tritonbench_args: list[str],
     input_shard_info: tuple[int, int] | None = None,
     operator_args: dict[str, Any] | None = None,
+    input_sample_mode: str = "first-n",
 ) -> None:
     """Run kernel variants in the same benchmark run."""
 
@@ -461,18 +467,61 @@ def _inner() -> Callable[..., Any] | object:
 
     from tritonbench.run import _run
 
-    # Handle input sharding if requested
+    # Get the actual number of inputs for this operator
+    total_inputs = Operator(
+        tb_args=tb_args, extra_args=unknown_args
+    )._available_num_inputs
+
+    # First, handle input sampling based on --num-inputs and --input-sample-mode
+    selected_indices = None
+
+    if "--num-inputs" in tritonbench_args:
+        # Make a copy to avoid modifying the original list
+        tritonbench_args = tritonbench_args.copy()
+        # Extract num-inputs value
+        num_inputs_idx = tritonbench_args.index("--num-inputs")
+        if num_inputs_idx + 1 < len(tritonbench_args):
+            num_inputs = int(tritonbench_args[num_inputs_idx + 1])
+
+            if input_sample_mode == "equal-spaced":
+                # Calculate equal-spaced indices
+                if num_inputs >= total_inputs:
+                    # If requesting more inputs than available, just use all
+                    selected_indices = list(range(total_inputs))
+                else:
+                    # Calculate step size for equal spacing
+                    step = (total_inputs - 1) / (num_inputs - 1) if num_inputs > 1 else 0
+                    selected_indices = [int(round(i * step)) for i in range(num_inputs)]
+
+                print(
+                    f"Step 1 - Equal-spaced sampling: {num_inputs} inputs from {total_inputs} total",
+                    file=sys.stderr,
+                )
+                print(f"  Selected indices: {selected_indices}", file=sys.stderr)
+            else:
+                # first-n mode: select first N inputs
+                selected_indices = list(range(min(num_inputs, total_inputs)))
+                print(
+                    f"Step 1 - First-n sampling: {num_inputs} inputs from {total_inputs} total",
+                    file=sys.stderr,
+                )
+                print(f"  Selected indices: {selected_indices}", file=sys.stderr)
+
+            # Remove --num-inputs from args since we'll handle it differently
+            tritonbench_args.pop(num_inputs_idx)  # Remove --num-inputs
+            tritonbench_args.pop(num_inputs_idx)  # Remove the value
+    else:
+        # No sampling requested, use all inputs
+        selected_indices = list(range(total_inputs))
+
+    # Then, handle sharding if requested
     if input_shard_info:
         shard_idx, total_shards = input_shard_info
-
-        # Get the actual number of inputs for this operator
-        total_inputs = Operator(
-            tb_args=tb_args, extra_args=unknown_args
-        )._available_num_inputs
-
-        # Calculate shard boundaries
-        inputs_per_shard = total_inputs // total_shards
-        extra_inputs = total_inputs % total_shards
+
+        # Calculate shard boundaries on the selected indices
+        num_selected = len(selected_indices)
+        inputs_per_shard = num_selected // total_shards
+        extra_inputs = num_selected % total_shards
 
         if shard_idx <= extra_inputs:
             start_idx = (shard_idx - 1) * (inputs_per_shard + 1)
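
For reference, a minimal sketch of the index math this hunk introduces (not part of the commit; the helper name pick_equal_spaced and the figures are made up for illustration). With num_inputs below total_inputs, the step size (total_inputs - 1) / (num_inputs - 1) spreads the picks so that, for more than one pick, the first and last inputs are always included.

# Illustration only, not from the commit; helper name and figures are hypothetical.
def pick_equal_spaced(total_inputs: int, num_inputs: int) -> list[int]:
    if num_inputs >= total_inputs:
        return list(range(total_inputs))
    # One pick at each end, the rest spread evenly in between
    step = (total_inputs - 1) / (num_inputs - 1) if num_inputs > 1 else 0
    return [int(round(i * step)) for i in range(num_inputs)]

print(pick_equal_spaced(12, 5))  # [0, 3, 6, 8, 11]
print(pick_equal_spaced(12, 1))  # [0]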
@@ -484,15 +533,21 @@ def _inner() -> Callable[..., Any] | object:
             )
             shard_size = inputs_per_shard
 
+        # Get the actual indices for this shard
+        shard_indices = selected_indices[start_idx:start_idx + shard_size]
+
         print(
-            f"Running input shard {shard_idx}/{total_shards}: inputs {start_idx} to {start_idx + shard_size - 1} (of {total_inputs} total)",
+            f"Step 2 - Sharding: shard {shard_idx}/{total_shards} gets {len(shard_indices)} inputs",
             file=sys.stderr,
         )
+        print(f"  Shard indices: {shard_indices}", file=sys.stderr)
+
+        # Update selected_indices to only include this shard
+        selected_indices = shard_indices
 
-        # Add input-id and num-inputs to the tritonbench args before re-parsing
-        tritonbench_args.extend(
-            ["--input-id", str(start_idx), "--num-inputs", str(shard_size)]
-        )
+        # Add the final selected indices to tritonbench args
+    if selected_indices is not None and len(selected_indices) > 0 and len(selected_indices) < total_inputs:
+        tritonbench_args.extend(["--input-id", ",".join(map(str, selected_indices))])
 
     # Re-parse args with the new input range
     tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
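
A minimal sketch (not part of the commit) of how the balanced sharding above slices the already-sampled indices before they are forwarded as a comma-separated --input-id. The helper name shard_slice and the figures are hypothetical, and it assumes tritonbench's --input-id accepts a comma-separated list, as the new code does.

# Illustration only, not from the commit; helper name and figures are hypothetical.
def shard_slice(selected: list[int], shard_idx: int, total_shards: int) -> list[int]:
    # shard_idx is 1-indexed; the first (len % total_shards) shards get one extra input
    per_shard, extra = divmod(len(selected), total_shards)
    if shard_idx <= extra:
        start = (shard_idx - 1) * (per_shard + 1)
        size = per_shard + 1
    else:
        start = extra * (per_shard + 1) + (shard_idx - 1 - extra) * per_shard
        size = per_shard
    return selected[start:start + size]

selected = [0, 3, 6, 8, 11]  # e.g. 5 equal-spaced picks out of 12 inputs
print(shard_slice(selected, 1, 2))                                     # [0, 3, 6]
print("--input-id", ",".join(map(str, shard_slice(selected, 2, 2))))   # --input-id 8,11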
@@ -523,6 +578,13 @@ def main() -> None:
         type=str,
         help="Run only a subset of inputs for each kernel. Format: M/N where M is the shard number (1-indexed) and N is the total number of shards. For example, --input-shard 1/3 runs the first third of inputs for each kernel.",
     )
+    parser.add_argument(
+        "--input-sample-mode",
+        type=str,
+        choices=["first-n", "equal-spaced"],
+        default="first-n",
+        help="How to sample inputs when using --num-inputs. 'first-n' (default) takes the first X inputs. 'equal-spaced' takes X inputs equally spaced throughout the input list.",
+    )
 
     # Parse known args to get the kernel name, pass rest to tritonbench
     args, tritonbench_args = parser.parse_known_args()
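
As a reminder of the mechanism the new flag relies on, a minimal sketch (not part of the commit) of the parse_known_args split: flags this wrapper declares, such as --input-sample-mode, are consumed here, while anything it does not recognize is passed through to tritonbench. The argv and the reduced parser below are hypothetical.

# Illustration only, not from the commit; argv and this reduced parser are hypothetical.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--kernel", type=str)
parser.add_argument("--input-sample-mode", choices=["first-n", "equal-spaced"], default="first-n")

argv = ["--kernel", "vector_add", "--num-inputs", "5", "--input-sample-mode", "equal-spaced"]
args, tritonbench_args = parser.parse_known_args(argv)
print(args.input_sample_mode)  # equal-spaced
print(tritonbench_args)        # ['--num-inputs', '5']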
@@ -568,7 +630,7 @@ def main() -> None:
 
         # Run specified kernels
         if len(kernel_names) == 1:
-            run_kernel(kernel_names[0], tritonbench_args, input_shard_info)
+            run_kernel(kernel_names[0], tritonbench_args, input_shard_info, args.input_sample_mode)
         else:
             print(
                 f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n",
@@ -578,15 +640,15 @@ def main() -> None:
                 print(f"\n{'=' * 60}", file=sys.stderr)
                 print(f"Kernel: {kernel_name}", file=sys.stderr)
                 print(f"{'=' * 60}\n", file=sys.stderr)
-                run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)
+                run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info, args.input_sample_mode)
     else:
         # Run all kernels
         print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
         for kernel_name in KERNEL_MAPPINGS:
             print(f"\n{'=' * 60}", file=sys.stderr)
             print(f"Kernel: {kernel_name}", file=sys.stderr)
             print(f"{'=' * 60}\n", file=sys.stderr)
-            run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)
+            run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info, args.input_sample_mode)
 
 
 if __name__ == "__main__":
