Skip to content

Commit b5b7658

Browse files
committed
add only_shapes filtering in KERNEL_MAPPINGS
1 parent 09382b0 commit b5b7658

File tree

1 file changed

+74
-2
lines changed

1 file changed

+74
-2
lines changed

benchmarks/run.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def run_kernel(
271271

272272
# Extract operator args if present
273273
operator_args = {}
274+
only_shapes = None
274275

275276
# Normalize to list of variants format
276277
if isinstance(mapping[1], list):
@@ -279,15 +280,21 @@ def run_kernel(
279280
variants = mapping[1]
280281
# Check if last element is args dict
281282
if len(mapping) > 2 and isinstance(mapping[2], dict):
282-
operator_args = mapping[2]
283+
operator_args = mapping[2].copy()
284+
# Extract only_shapes if present
285+
if "only_shapes" in operator_args:
286+
only_shapes = operator_args.pop("only_shapes")
283287
else:
284288
# Single kernel format
285289
if len(mapping) == 4 and isinstance(mapping[3], dict):
286290
# With args
287291
tritonbench_module = mapping[0]
288292
module_path = mapping[1]
289293
func_name = mapping[2]
290-
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
294+
operator_args = mapping[3].copy() # pyright: ignore[reportGeneralTypeIssues]
295+
# Extract only_shapes if present
296+
if "only_shapes" in operator_args:
297+
only_shapes = operator_args.pop("only_shapes")
291298
variants = [(module_path, func_name)]
292299
else:
293300
# Without args
@@ -303,6 +310,7 @@ def run_kernel(
303310
tritonbench_args,
304311
input_shard_info,
305312
operator_args,
313+
only_shapes,
306314
)
307315

308316

@@ -313,6 +321,7 @@ def run_kernel_variants(
313321
tritonbench_args: list[str],
314322
input_shard_info: tuple[int, int] | None = None,
315323
operator_args: dict[str, Any] | None = None,
324+
only_shapes: list[str] | None = None,
316325
) -> None:
317326
"""Run kernel variants in the same benchmark run."""
318327

@@ -377,6 +386,69 @@ def run_kernel_variants(
377386
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
378387
register_benchmark,
379388
)
389+
390+
# Inject only_shapes filter if provided
391+
if only_shapes:
392+
print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)
393+
394+
# Override the get_input_iter method for the operator class
395+
original_get_input_iter = Operator.get_input_iter
396+
original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
397+
398+
# Create a list to store filtered inputs and their shapes
399+
filtered_inputs = []
400+
401+
# First, collect all inputs that match the shape filter
402+
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
403+
for inputs in original_get_input_iter(temp_operator):
404+
# Get the shape value for this input
405+
shape_value = None
406+
407+
if original_get_x_val:
408+
# Use the operator's get_x_val method to get shape representation
409+
shape_value = original_get_x_val(temp_operator, inputs)
410+
else:
411+
# Fallback: try to get shape from the inputs directly
412+
if isinstance(inputs, tuple) and len(inputs) > 0:
413+
if hasattr(inputs[0], 'shape'):
414+
shape_value = list(inputs[0].shape)
415+
elif isinstance(inputs[0], (int, float)):
416+
shape_value = inputs[0]
417+
else:
418+
# For complex inputs, try to extract meaningful shape info
419+
shape_value = inputs
420+
421+
# Check if this shape matches any in our filter using direct comparison
422+
match_found = False
423+
for expected_shape in only_shapes:
424+
if shape_value == expected_shape:
425+
match_found = True
426+
break
427+
# Also check if shape_value is a tuple/list that matches
428+
elif isinstance(shape_value, (tuple, list)) and isinstance(expected_shape, (tuple, list)):
429+
if len(shape_value) == len(expected_shape) and all(a == b for a, b in zip(shape_value, expected_shape)):
430+
match_found = True
431+
break
432+
433+
if match_found:
434+
filtered_inputs.append(inputs)
435+
print(f" Including shape: {shape_value}", file=sys.stderr)
436+
437+
del temp_operator # Clean up temporary operator
438+
439+
if not filtered_inputs:
440+
print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)
441+
442+
def filtered_get_input_iter(self):
443+
"""Custom input iterator that only yields filtered shapes."""
444+
for inputs in filtered_inputs:
445+
yield inputs
446+
447+
# Monkey-patch the operator class
448+
Operator.get_input_iter = filtered_get_input_iter
449+
450+
# Also override _available_num_inputs for proper sharding support
451+
Operator._available_num_inputs = len(filtered_inputs)
380452

381453
# Register all variants as separate methods
382454
for module_path, func_name in variants:

0 commit comments

Comments
 (0)