@@ -281,6 +281,7 @@ def run_kernel(
281
281
282
282
# Extract operator args if present
283
283
operator_args = {}
284
+ only_shapes = None
284
285
285
286
# Normalize to list of variants format
286
287
if isinstance (mapping [1 ], list ):
@@ -289,15 +290,21 @@ def run_kernel(
289
290
variants = mapping [1 ]
290
291
# Check if last element is args dict
291
292
if len (mapping ) > 2 and isinstance (mapping [2 ], dict ):
292
- operator_args = mapping [2 ]
293
+ operator_args = mapping [2 ].copy ()
294
+ # Extract only_shapes if present
295
+ if "only_shapes" in operator_args :
296
+ only_shapes = operator_args .pop ("only_shapes" )
293
297
else :
294
298
# Single kernel format
295
299
if len (mapping ) == 4 and isinstance (mapping [3 ], dict ):
296
300
# With args
297
301
tritonbench_module = mapping [0 ]
298
302
module_path = mapping [1 ]
299
303
func_name = mapping [2 ]
300
- operator_args = mapping [3 ] # pyright: ignore[reportGeneralTypeIssues]
304
+ operator_args = mapping [3 ].copy () # pyright: ignore[reportGeneralTypeIssues]
305
+ # Extract only_shapes if present
306
+ if "only_shapes" in operator_args :
307
+ only_shapes = operator_args .pop ("only_shapes" )
301
308
variants = [(module_path , func_name )]
302
309
else :
303
310
# Without args
@@ -313,6 +320,7 @@ def run_kernel(
313
320
tritonbench_args ,
314
321
input_shard_info ,
315
322
operator_args ,
323
+ only_shapes ,
316
324
)
317
325
318
326
@@ -323,6 +331,7 @@ def run_kernel_variants(
323
331
tritonbench_args : list [str ],
324
332
input_shard_info : tuple [int , int ] | None = None ,
325
333
operator_args : dict [str , Any ] | None = None ,
334
+ only_shapes : list [str ] | None = None ,
326
335
) -> None :
327
336
"""Run kernel variants in the same benchmark run."""
328
337
@@ -374,6 +383,69 @@ def run_kernel_variants(
374
383
from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
375
384
register_benchmark ,
376
385
)
386
+
387
+ # Inject only_shapes filter if provided
388
+ if only_shapes :
389
+ print (f"Using only_shapes for { kernel_name } : { only_shapes } " , file = sys .stderr )
390
+
391
+ # Override the get_input_iter method for the operator class
392
+ original_get_input_iter = Operator .get_input_iter
393
+ original_get_x_val = Operator .get_x_val if hasattr (Operator , 'get_x_val' ) else None
394
+
395
+ # Create a list to store filtered inputs and their shapes
396
+ filtered_inputs = []
397
+
398
+ # First, collect all inputs that match the shape filter
399
+ temp_operator = Operator (tb_args = tb_args , extra_args = unknown_args )
400
+ for inputs in original_get_input_iter (temp_operator ):
401
+ # Get the shape value for this input
402
+ shape_value = None
403
+
404
+ if original_get_x_val :
405
+ # Use the operator's get_x_val method to get shape representation
406
+ shape_value = original_get_x_val (temp_operator , inputs )
407
+ else :
408
+ # Fallback: try to get shape from the inputs directly
409
+ if isinstance (inputs , tuple ) and len (inputs ) > 0 :
410
+ if hasattr (inputs [0 ], 'shape' ):
411
+ shape_value = list (inputs [0 ].shape )
412
+ elif isinstance (inputs [0 ], (int , float )):
413
+ shape_value = inputs [0 ]
414
+ else :
415
+ # For complex inputs, try to extract meaningful shape info
416
+ shape_value = inputs
417
+
418
+ # Check if this shape matches any in our filter using direct comparison
419
+ match_found = False
420
+ for expected_shape in only_shapes :
421
+ if shape_value == expected_shape :
422
+ match_found = True
423
+ break
424
+ # Also check if shape_value is a tuple/list that matches
425
+ elif isinstance (shape_value , (tuple , list )) and isinstance (expected_shape , (tuple , list )):
426
+ if len (shape_value ) == len (expected_shape ) and all (a == b for a , b in zip (shape_value , expected_shape )):
427
+ match_found = True
428
+ break
429
+
430
+ if match_found :
431
+ filtered_inputs .append (inputs )
432
+ print (f" Including shape: { shape_value } " , file = sys .stderr )
433
+
434
+ del temp_operator # Clean up temporary operator
435
+
436
+ if not filtered_inputs :
437
+ print (f"Warning: No shapes matched the filter for { kernel_name } " , file = sys .stderr )
438
+
439
+ def filtered_get_input_iter (self ):
440
+ """Custom input iterator that only yields filtered shapes."""
441
+ for inputs in filtered_inputs :
442
+ yield inputs
443
+
444
+ # Monkey-patch the operator class
445
+ Operator .get_input_iter = filtered_get_input_iter
446
+
447
+ # Also override _available_num_inputs for proper sharding support
448
+ Operator ._available_num_inputs = len (filtered_inputs )
377
449
378
450
# Register all variants as separate methods
379
451
for module_path , func_name in variants :
0 commit comments