Skip to content

Commit 6c5c4ca

Browse files
authored
[Benchmark] Allow running a specific shard of input via --input-shard M/N cli arg (#377)
1 parent 462fc00 commit 6c5c4ca

File tree

1 file changed

+62
-4
lines changed

1 file changed

+62
-4
lines changed

benchmarks/run.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,11 @@ def check_and_setup_tritonbench() -> None:
205205
sys.exit(1)
206206

207207

208-
def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
208+
def run_kernel(
209+
kernel_name: str,
210+
tritonbench_args: list[str],
211+
input_shard_info: tuple[int, int] | None = None,
212+
) -> None:
209213
"""Run a single kernel benchmark."""
210214
# Check if kernel is in the mapping table
211215
if kernel_name not in KERNEL_MAPPINGS:
@@ -343,6 +347,36 @@ def _inner() -> Callable[..., Any] | object:
343347
# Create and run the operator with unknown args
344348
op = Operator(tb_args=tb_args, extra_args=unknown_args)
345349

350+
# Handle input sharding if requested
351+
if input_shard_info:
352+
shard_idx, total_shards = input_shard_info
353+
354+
# Get the actual number of inputs for this operator
355+
total_inputs = op._available_num_inputs
356+
357+
# Calculate shard boundaries
358+
inputs_per_shard = total_inputs // total_shards
359+
extra_inputs = total_inputs % total_shards
360+
361+
if shard_idx <= extra_inputs:
362+
start_idx = (shard_idx - 1) * (inputs_per_shard + 1)
363+
shard_size = inputs_per_shard + 1
364+
else:
365+
start_idx = (
366+
extra_inputs * (inputs_per_shard + 1)
367+
+ (shard_idx - 1 - extra_inputs) * inputs_per_shard
368+
)
369+
shard_size = inputs_per_shard
370+
371+
# Override the operator's input range
372+
op._input_id = start_idx
373+
op._num_inputs = shard_size
374+
375+
print(
376+
f"Running input shard {shard_idx}/{total_shards}: inputs {start_idx} to {start_idx + shard_size - 1} (of {total_inputs} total)",
377+
file=sys.stderr,
378+
)
379+
346380
# Run with proper parameters
347381
warmup = int(getattr(tb_args, "warmup", 25))
348382
rep = int(getattr(tb_args, "iter", 100))
@@ -369,13 +403,37 @@ def main() -> None:
369403
type=str,
370404
help="Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels.",
371405
)
406+
parser.add_argument(
407+
"--input-shard",
408+
type=str,
409+
help="Run only a subset of inputs for each kernel. Format: M/N where M is the shard number (1-indexed) and N is the total number of shards. For example, --input-shard 1/3 runs the first third of inputs for each kernel.",
410+
)
372411

373412
# Parse known args to get the kernel name, pass rest to tritonbench
374413
args, tritonbench_args = parser.parse_known_args()
375414

376415
# Check and setup tritonbench if needed
377416
check_and_setup_tritonbench()
378417

418+
# Store input-shard info for later processing
419+
input_shard_info = None
420+
if args.input_shard:
421+
try:
422+
shard_idx, total_shards = map(int, args.input_shard.split("/"))
423+
if shard_idx < 1 or shard_idx > total_shards:
424+
print(
425+
f"Error: Shard number {shard_idx} must be between 1 and {total_shards}",
426+
file=sys.stderr,
427+
)
428+
sys.exit(1)
429+
input_shard_info = (shard_idx, total_shards)
430+
except ValueError:
431+
print(
432+
f"Error: Invalid input-shard format '{args.input_shard}'. Expected format: M/N (e.g., 1/3)",
433+
file=sys.stderr,
434+
)
435+
sys.exit(1)
436+
379437
if args.kernel:
380438
# Parse comma-separated kernel names
381439
kernel_names = [k.strip() for k in args.kernel.split(",")]
@@ -395,7 +453,7 @@ def main() -> None:
395453

396454
# Run specified kernels
397455
if len(kernel_names) == 1:
398-
run_kernel(kernel_names[0], tritonbench_args)
456+
run_kernel(kernel_names[0], tritonbench_args, input_shard_info)
399457
else:
400458
print(
401459
f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n",
@@ -405,15 +463,15 @@ def main() -> None:
405463
print(f"\n{'=' * 60}", file=sys.stderr)
406464
print(f"Kernel: {kernel_name}", file=sys.stderr)
407465
print(f"{'=' * 60}\n", file=sys.stderr)
408-
run_kernel(kernel_name, tritonbench_args.copy())
466+
run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)
409467
else:
410468
# Run all kernels
411469
print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
412470
for kernel_name in KERNEL_MAPPINGS:
413471
print(f"\n{'=' * 60}", file=sys.stderr)
414472
print(f"Kernel: {kernel_name}", file=sys.stderr)
415473
print(f"{'=' * 60}\n", file=sys.stderr)
416-
run_kernel(kernel_name, tritonbench_args.copy())
474+
run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)
417475

418476

419477
if __name__ == "__main__":

0 commit comments

Comments
 (0)