Skip to content

Commit db2988b

Browse files
committed
add only_shapes filtering in KERNEL_MAPPINGS
1 parent 27fea44 commit db2988b

File tree

1 file changed

+74
-2
lines changed

1 file changed

+74
-2
lines changed

benchmarks/run.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def run_kernel(
281281

282282
# Extract operator args if present
283283
operator_args = {}
284+
only_shapes = None
284285

285286
# Normalize to list of variants format
286287
if isinstance(mapping[1], list):
@@ -289,15 +290,21 @@ def run_kernel(
289290
variants = mapping[1]
290291
# Check if last element is args dict
291292
if len(mapping) > 2 and isinstance(mapping[2], dict):
292-
operator_args = mapping[2]
293+
operator_args = mapping[2].copy()
294+
# Extract only_shapes if present
295+
if "only_shapes" in operator_args:
296+
only_shapes = operator_args.pop("only_shapes")
293297
else:
294298
# Single kernel format
295299
if len(mapping) == 4 and isinstance(mapping[3], dict):
296300
# With args
297301
tritonbench_module = mapping[0]
298302
module_path = mapping[1]
299303
func_name = mapping[2]
300-
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
304+
operator_args = mapping[3].copy() # pyright: ignore[reportGeneralTypeIssues]
305+
# Extract only_shapes if present
306+
if "only_shapes" in operator_args:
307+
only_shapes = operator_args.pop("only_shapes")
301308
variants = [(module_path, func_name)]
302309
else:
303310
# Without args
@@ -313,6 +320,7 @@ def run_kernel(
313320
tritonbench_args,
314321
input_shard_info,
315322
operator_args,
323+
only_shapes,
316324
)
317325

318326

@@ -323,6 +331,7 @@ def run_kernel_variants(
323331
tritonbench_args: list[str],
324332
input_shard_info: tuple[int, int] | None = None,
325333
operator_args: dict[str, Any] | None = None,
334+
only_shapes: list[str] | None = None,
326335
) -> None:
327336
"""Run kernel variants in the same benchmark run."""
328337

@@ -374,6 +383,69 @@ def run_kernel_variants(
374383
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
375384
register_benchmark,
376385
)
386+
387+
# Inject only_shapes filter if provided
388+
if only_shapes:
389+
print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)
390+
391+
# Override the get_input_iter method for the operator class
392+
original_get_input_iter = Operator.get_input_iter
393+
original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
394+
395+
# Create a list to store filtered inputs and their shapes
396+
filtered_inputs = []
397+
398+
# First, collect all inputs that match the shape filter
399+
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
400+
for inputs in original_get_input_iter(temp_operator):
401+
# Get the shape value for this input
402+
shape_value = None
403+
404+
if original_get_x_val:
405+
# Use the operator's get_x_val method to get shape representation
406+
shape_value = original_get_x_val(temp_operator, inputs)
407+
else:
408+
# Fallback: try to get shape from the inputs directly
409+
if isinstance(inputs, tuple) and len(inputs) > 0:
410+
if hasattr(inputs[0], 'shape'):
411+
shape_value = list(inputs[0].shape)
412+
elif isinstance(inputs[0], (int, float)):
413+
shape_value = inputs[0]
414+
else:
415+
# For complex inputs, try to extract meaningful shape info
416+
shape_value = inputs
417+
418+
# Check if this shape matches any in our filter using direct comparison
419+
match_found = False
420+
for expected_shape in only_shapes:
421+
if shape_value == expected_shape:
422+
match_found = True
423+
break
424+
# Also check if shape_value is a tuple/list that matches
425+
elif isinstance(shape_value, (tuple, list)) and isinstance(expected_shape, (tuple, list)):
426+
if len(shape_value) == len(expected_shape) and all(a == b for a, b in zip(shape_value, expected_shape)):
427+
match_found = True
428+
break
429+
430+
if match_found:
431+
filtered_inputs.append(inputs)
432+
print(f" Including shape: {shape_value}", file=sys.stderr)
433+
434+
del temp_operator # Clean up temporary operator
435+
436+
if not filtered_inputs:
437+
print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)
438+
439+
def filtered_get_input_iter(self):
440+
"""Custom input iterator that only yields filtered shapes."""
441+
for inputs in filtered_inputs:
442+
yield inputs
443+
444+
# Monkey-patch the operator class
445+
Operator.get_input_iter = filtered_get_input_iter
446+
447+
# Also override _available_num_inputs for proper sharding support
448+
Operator._available_num_inputs = len(filtered_inputs)
377449

378450
# Register all variants as separate methods
379451
for module_path, func_name in variants:

0 commit comments

Comments
 (0)