@@ -26,7 +26,7 @@ class Pattern:
     In subclass, define description and skip property.
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         self.prof = prof
         self.should_benchmark = should_benchmark
         self.name = "Please specify a name for pattern"
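Subclasses are expected to set `name`/`description` and override `skip` and `match`. A minimal sketch of a custom pattern against the base-class API shown in this diff (the class and op name are hypothetical):

```python
class MySlowOpPattern(Pattern):
    """Hypothetical pattern: flags a slow op purely by event name."""

    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
        super().__init__(prof, should_benchmark)
        self.name = "My Slow Op Pattern"
        self.description = "Found aten::my_slow_op; consider a fused alternative."

    def match(self, event: _ProfilerEvent) -> bool:
        # Match purely on the op name, mirroring NamePattern below.
        return event.name == "aten::my_slow_op"
```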
@@ -39,7 +39,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
             self.tid_root.setdefault(event.start_tid, []).append(event)

     @property
-    def skip(self):
+    def skip(self) -> bool:
         return False

     def report(self, event: _ProfilerEvent):
@@ -66,8 +66,8 @@ def summary(self, events: list[_ProfilerEvent]):
             )
         return default_summary

-    def benchmark_summary(self, events: list[_ProfilerEvent]):
-        def format_time(time_ns: int):
+    def benchmark_summary(self, events: list[_ProfilerEvent]) -> str:
+        def format_time(time_ns: int) -> str:
             unit_lst = ["ns", "us", "ms"]
             for unit in unit_lst:
                 if time_ns < 1000:
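A worked example of the conversion, assuming the truncating `time_ns //= 1000` loop body from the full source (only the first lines of `format_time` appear in this hunk):

```python
format_time(950)        # -> "950.00 ns"
format_time(2_500_000)  # -> "2.00 ms"; floor division drops the fraction each step
```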
@@ -135,7 +135,9 @@ def go_up_until(self, event: _ProfilerEvent, predicate):


 class NamePattern(Pattern):
-    def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
+    def __init__(
+        self, prof: profile, name: str, should_benchmark: bool = False
+    ) -> None:
         super().__init__(prof, should_benchmark)
         self.description = f"Matched Name Event: {name}"
         self.name = name
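As a usage sketch (these are internal APIs, and `model`/`inputs` are placeholders), `NamePattern` flags every event whose name matches:

```python
with profile(with_stack=True, record_shapes=True) as prof:
    model(inputs)

pattern = NamePattern(prof, "aten::add")
for event in pattern.matched_events():  # matched_events() is defined on Pattern in the full source
    print(pattern.report(event))
```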
@@ -161,7 +163,7 @@ class ExtraCUDACopyPattern(Pattern):
     If at any step we failed, it is not a match.
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Extra CUDA Copy Pattern"
         self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
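A minimal before/after sketch of what this pattern flags:

```python
import torch

# Flagged: the tensor is materialized on the CPU, then copied to the GPU.
t = torch.zeros(128, 128).to("cuda")

# Preferred: allocate directly on the device; no intermediate copy.
t = torch.zeros(128, 128, device="cuda")
```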
@@ -174,7 +176,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
         }

     @property
-    def skip(self):
+    def skip(self) -> bool:
         return not self.prof.with_stack or not self.prof.record_shapes

     def match(self, event):
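Because `skip` checks `with_stack` and `record_shapes`, the profiler session must enable both for this pattern to run; a sketch (`train_step` is a placeholder for the profiled workload):

```python
from torch.profiler import profile

with profile(with_stack=True, record_shapes=True) as prof:
    train_step()
```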
@@ -248,7 +250,7 @@ class ForLoopIndexingPattern(Pattern):
     We also keep a dictionary to avoid duplicate match in the for loop.
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "For Loop Indexing Pattern"
         self.description = "For loop indexing detected. Vectorization recommended."
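A before/after sketch of the anti-pattern this class detects:

```python
import torch

a, b = torch.randn(1000), torch.randn(1000)

# Flagged: one aten::select per iteration, thousands of tiny ops.
c = torch.empty(1000)
for i in range(1000):
    c[i] = a[i] + b[i]

# Preferred: a single vectorized op.
c = a + b
```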
@@ -271,7 +273,7 @@ def match(self, event: _ProfilerEvent):
             return False

         # Custom event list matching
-        def same_ops(list1, list2):
+        def same_ops(list1, list2) -> bool:
             if len(list1) != len(list2):
                 return False
             for op1, op2 in zip(list1, list2):
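`same_ops` treats two event lists as equal when they have the same length and pairwise-equal names; the loop body is cut off in this hunk, so here is the check completed under that assumption:

```python
def same_ops(list1, list2) -> bool:
    if len(list1) != len(list2):
        return False
    for op1, op2 in zip(list1, list2):
        if op1.name != op2.name:  # assumed comparison, per the surrounding logic
            return False
    return True
```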
@@ -295,7 +297,7 @@ def same_ops(list1, list2):


 class FP32MatMulPattern(Pattern):
-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "FP32 MatMul Pattern"
         self.description = (
@@ -316,7 +318,7 @@ def skip(self):
         )
         return has_tf32 is False or super().skip or not self.prof.record_shapes

-    def match(self, event: _ProfilerEvent):
+    def match(self, event: _ProfilerEvent) -> bool:
         # If we saw this pattern once, we don't need to match it again
         if event.tag != _EventType.TorchOp:
             return False
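The remedy this pattern points at is TF32 for FP32 matmuls on Ampere-class GPUs; one way to opt in globally:

```python
import torch

# Allow TF32 tensor-core math for matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```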
@@ -365,7 +367,7 @@ class OptimizerSingleTensorPattern(Pattern):
     String match
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Optimizer Single Tensor Pattern"
         self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
@@ -375,7 +377,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
         )
         self.url = ""

-    def match(self, event: _ProfilerEvent):
+    def match(self, event: _ProfilerEvent) -> bool:
         for optimizer in self.optimizers_with_foreach:
             if event.name.endswith(f"_single_tensor_{optimizer}"):
                 return True
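The multi-tensor path this pattern recommends can be requested explicitly (`model` is a placeholder):

```python
import torch

# foreach=True selects the multi-tensor (horizontally fused) implementation.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)
```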
@@ -400,7 +402,7 @@ class SynchronizedDataLoaderPattern(Pattern):

     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Synchronized DataLoader Pattern"
         self.description = (
@@ -412,7 +414,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
412414 "#enable-async-data-loading-and-augmentation"
413415 )
414416
415- def match (self , event : _ProfilerEvent ):
417+ def match (self , event : _ProfilerEvent ) -> bool :
416418 def is_dataloader_function (name : str , function_name : str ):
417419 return name .startswith (
418420 os .path .join ("torch" , "utils" , "data" , "dataloader.py" )
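The fix the description points to is asynchronous loading (`dataset` is a placeholder):

```python
from torch.utils.data import DataLoader

# num_workers > 0 prepares batches in worker processes instead of blocking
# the training loop; pin_memory speeds up host-to-GPU transfers.
loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
```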
@@ -459,7 +461,7 @@ class GradNotSetToNonePattern(Pattern):
     String match
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Gradient Set To Zero Instead of None Pattern"
         self.description = (
@@ -471,7 +473,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
471473 "#disable-gradient-calculation-for-validation-or-inference"
472474 )
473475
474- def match (self , event : _ProfilerEvent ):
476+ def match (self , event : _ProfilerEvent ) -> bool :
475477 if not event .name .endswith (": zero_grad" ):
476478 return False
477479 if not event .children :
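A before/after sketch (`optimizer` is a placeholder):

```python
# Flagged: writes zeros into every .grad tensor.
optimizer.zero_grad(set_to_none=False)

# Preferred: drops the gradient tensors, saving memory and kernel launches.
optimizer.zero_grad(set_to_none=True)
```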
@@ -500,7 +502,7 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
     String match
     """

-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
         self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
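A sketch of the recommended construction:

```python
import torch.nn as nn

# BatchNorm2d re-centers its input, so a conv bias before it is redundant work.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
```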
@@ -531,17 +533,17 @@ def match(self, event: _ProfilerEvent):


 class MatMulDimInFP16Pattern(Pattern):
-    def __init__(self, prof: profile, should_benchmark: bool = False):
+    def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
         super().__init__(prof, should_benchmark)
         self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
         self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
         self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"

     @property
-    def skip(self):
+    def skip(self) -> bool:
         return not self.prof.with_stack or not self.prof.record_shapes

-    def match(self, event: _ProfilerEvent):
+    def match(self, event: _ProfilerEvent) -> bool:
         def mutiple_of(shapes, multiple):
             return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])

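The check inspects the last two dimensions of each matmul operand; FP16 tensor cores want multiples of 8. A sketch of padding an odd inner dimension (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

a = torch.randn(64, 125, dtype=torch.float16, device="cuda")
b = torch.randn(125, 64, dtype=torch.float16, device="cuda")

# Pad the shared inner dimension 125 -> 128, the next multiple of 8.
a_pad = F.pad(a, (0, 3))        # pad last dim of a
b_pad = F.pad(b, (0, 0, 0, 3))  # pad second-to-last dim of b
out = a_pad @ b_pad             # zero padding contributes nothing to the sum
```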
@@ -584,7 +586,7 @@ def closest_multiple(shapes, multiple):
         return shapes_factor_map


-def source_code_location(event: Optional[_ProfilerEvent]):
+def source_code_location(event: Optional[_ProfilerEvent]) -> str:
     while event:
         if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
             assert isinstance(
@@ -611,7 +613,7 @@ def report_all_anti_patterns(
     should_benchmark: bool = False,
     print_enable: bool = True,
     json_report_dir: Optional[str] = None,
-):
+) -> None:
     report_dict: dict = {}
     anti_patterns = [
         ExtraCUDACopyPattern(prof, should_benchmark),
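End to end, this entry point runs every pattern in the list above; a usage sketch (note it lives in an underscore module, so the import path is internal and may change):

```python
from torch.profiler import profile
from torch.profiler._pattern_matcher import report_all_anti_patterns

with profile(with_stack=True, record_shapes=True) as prof:
    train_step()  # placeholder workload

report_all_anti_patterns(prof, print_enable=True)
```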