@@ -271,6 +271,7 @@ def run_kernel(
271
271
272
272
# Extract operator args if present
273
273
operator_args = {}
274
+ only_shapes = None
274
275
275
276
# Normalize to list of variants format
276
277
if isinstance (mapping [1 ], list ):
@@ -279,15 +280,21 @@ def run_kernel(
279
280
variants = mapping [1 ]
280
281
# Check if last element is args dict
281
282
if len (mapping ) > 2 and isinstance (mapping [2 ], dict ):
282
- operator_args = mapping [2 ]
283
+ operator_args = mapping [2 ].copy ()
284
+ # Extract only_shapes if present
285
+ if "only_shapes" in operator_args :
286
+ only_shapes = operator_args .pop ("only_shapes" )
283
287
else :
284
288
# Single kernel format
285
289
if len (mapping ) == 4 and isinstance (mapping [3 ], dict ):
286
290
# With args
287
291
tritonbench_module = mapping [0 ]
288
292
module_path = mapping [1 ]
289
293
func_name = mapping [2 ]
290
- operator_args = mapping [3 ] # pyright: ignore[reportGeneralTypeIssues]
294
+ operator_args = mapping [3 ].copy () # pyright: ignore[reportGeneralTypeIssues]
295
+ # Extract only_shapes if present
296
+ if "only_shapes" in operator_args :
297
+ only_shapes = operator_args .pop ("only_shapes" )
291
298
variants = [(module_path , func_name )]
292
299
else :
293
300
# Without args
@@ -303,6 +310,7 @@ def run_kernel(
303
310
tritonbench_args ,
304
311
input_shard_info ,
305
312
operator_args ,
313
+ only_shapes ,
306
314
)
307
315
308
316
@@ -313,6 +321,7 @@ def run_kernel_variants(
313
321
tritonbench_args : list [str ],
314
322
input_shard_info : tuple [int , int ] | None = None ,
315
323
operator_args : dict [str , Any ] | None = None ,
324
+ only_shapes : list [str ] | None = None ,
316
325
) -> None :
317
326
"""Run kernel variants in the same benchmark run."""
318
327
@@ -377,6 +386,69 @@ def run_kernel_variants(
377
386
from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
378
387
register_benchmark ,
379
388
)
389
+
390
+ # Inject only_shapes filter if provided
391
+ if only_shapes :
392
+ print (f"Using only_shapes for { kernel_name } : { only_shapes } " , file = sys .stderr )
393
+
394
+ # Override the get_input_iter method for the operator class
395
+ original_get_input_iter = Operator .get_input_iter
396
+ original_get_x_val = Operator .get_x_val if hasattr (Operator , 'get_x_val' ) else None
397
+
398
+ # Create a list to store filtered inputs and their shapes
399
+ filtered_inputs = []
400
+
401
+ # First, collect all inputs that match the shape filter
402
+ temp_operator = Operator (tb_args = tb_args , extra_args = unknown_args )
403
+ for inputs in original_get_input_iter (temp_operator ):
404
+ # Get the shape value for this input
405
+ shape_value = None
406
+
407
+ if original_get_x_val :
408
+ # Use the operator's get_x_val method to get shape representation
409
+ shape_value = original_get_x_val (temp_operator , inputs )
410
+ else :
411
+ # Fallback: try to get shape from the inputs directly
412
+ if isinstance (inputs , tuple ) and len (inputs ) > 0 :
413
+ if hasattr (inputs [0 ], 'shape' ):
414
+ shape_value = list (inputs [0 ].shape )
415
+ elif isinstance (inputs [0 ], (int , float )):
416
+ shape_value = inputs [0 ]
417
+ else :
418
+ # For complex inputs, try to extract meaningful shape info
419
+ shape_value = inputs
420
+
421
+ # Check if this shape matches any in our filter using direct comparison
422
+ match_found = False
423
+ for expected_shape in only_shapes :
424
+ if shape_value == expected_shape :
425
+ match_found = True
426
+ break
427
+ # Also check if shape_value is a tuple/list that matches
428
+ elif isinstance (shape_value , (tuple , list )) and isinstance (expected_shape , (tuple , list )):
429
+ if len (shape_value ) == len (expected_shape ) and all (a == b for a , b in zip (shape_value , expected_shape )):
430
+ match_found = True
431
+ break
432
+
433
+ if match_found :
434
+ filtered_inputs .append (inputs )
435
+ print (f" Including shape: { shape_value } " , file = sys .stderr )
436
+
437
+ del temp_operator # Clean up temporary operator
438
+
439
+ if not filtered_inputs :
440
+ print (f"Warning: No shapes matched the filter for { kernel_name } " , file = sys .stderr )
441
+
442
def filtered_get_input_iter(self):
    """Yield the pre-filtered inputs one by one, in stored order.

    ``self`` is accepted so the function can be installed as a method,
    but it is not consulted; the items come from the ``filtered_inputs``
    list captured from the enclosing scope.
    """
    # Delegate iteration directly to the captured list.
    yield from filtered_inputs
446
+
447
+ # Monkey-patch the operator class
448
+ Operator .get_input_iter = filtered_get_input_iter
449
+
450
+ # Also override _available_num_inputs for proper sharding support
451
+ Operator ._available_num_inputs = len (filtered_inputs )
380
452
381
453
# Register all variants as separate methods
382
454
for module_path , func_name in variants :
0 commit comments