@@ -347,6 +347,7 @@ def inject_profiling_into_existing_test(
347347 ast .Import (names = [ast .alias (name = "time" )]),
348348 ast .Import (names = [ast .alias (name = "gc" )]),
349349 ast .Import (names = [ast .alias (name = "os" )]),
350+ ast .Import (names = [ast .alias (name = "torch" )])
350351 ]
351352 if mode == TestingMode .BEHAVIOR :
352353 new_imports .extend (
@@ -524,70 +525,114 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
524525 ast .Try (
525526 body = [
526527 ast .Assign (
527- targets = [ast .Name (id = "counter" , ctx = ast .Store ())],
528+ targets = [
529+ ast .Name (id = 'start' , ctx = ast .Store ())],
528530 value = ast .Call (
529531 func = ast .Attribute (
530- value = ast .Name (id = "time" , ctx = ast .Load ()), attr = "perf_counter_ns" , ctx = ast .Load ()
531- ),
532+ value = ast .Attribute (
533+ value = ast .Name (id = 'torch' , ctx = ast .Load ()),
534+ attr = 'cuda' ,
535+ ctx = ast .Load ()),
536+ attr = 'Event' ,
537+ ctx = ast .Load ()),
532538 args = [],
533- keywords = [],
534- ),
535- lineno = lineno + 11 ,
536- ),
539+ keywords = [
540+ ast .keyword (
541+ arg = 'enable_timing' ,
542+ value = ast .Constant (value = True ))]), lineno = lineno + 11 ),
543+ ast .Assign (
544+ targets = [
545+ ast .Name (id = 'end' , ctx = ast .Store ())],
546+ value = ast .Call (
547+ func = ast .Attribute (
548+ value = ast .Attribute (
549+ value = ast .Name (id = 'torch' , ctx = ast .Load ()),
550+ attr = 'cuda' ,
551+ ctx = ast .Load ()),
552+ attr = 'Event' ,
553+ ctx = ast .Load ()),
554+ args = [],
555+ keywords = [
556+ ast .keyword (
557+ arg = 'enable_timing' ,
558+ value = ast .Constant (value = True ))]), lineno = lineno + 12 ),
559+ ast .Expr (
560+ value = ast .Call (
561+ func = ast .Attribute (
562+ value = ast .Name (id = 'start' , ctx = ast .Load ()),
563+ attr = 'record' ,
564+ ctx = ast .Load ()),
565+ args = [],
566+ keywords = []), lineno = lineno + 13 ),
537567 ast .Assign (
538568 targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
539569 value = ast .Call (
540570 func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
541571 args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
542572 keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
543573 ),
544- lineno = lineno + 12 ,
574+ lineno = lineno + 13 ,
545575 ),
576+ ast .Expr (
577+ value = ast .Call (
578+ func = ast .Attribute (
579+ value = ast .Name (id = 'end' , ctx = ast .Load ()),
580+ attr = 'record' ,
581+ ctx = ast .Load ()),
582+ args = [],
583+ keywords = []), lineno = lineno + 14 ),
584+ ast .Expr (
585+ value = ast .Call (
586+ func = ast .Attribute (
587+ value = ast .Attribute (
588+ value = ast .Name (id = 'torch' , ctx = ast .Load ()),
589+ attr = 'cuda' ,
590+ ctx = ast .Load ()),
591+ attr = 'synchronize' ,
592+ ctx = ast .Load ()),
593+ args = [],
594+ keywords = []),lineno = lineno + 15 ),
546595 ast .Assign (
547- targets = [ast .Name (id = "codeflash_duration" , ctx = ast .Store ())],
596+ targets = [
597+ ast .Name (id = 'codeflash_duration' , ctx = ast .Store ())],
548598 value = ast .BinOp (
549599 left = ast .Call (
550600 func = ast .Attribute (
551- value = ast .Name (id = "time" , ctx = ast .Load ()), attr = "perf_counter_ns" , ctx = ast .Load ()
552- ),
553- args = [],
554- keywords = [],
555- ),
556- op = ast .Sub (),
557- right = ast .Name (id = "counter" , ctx = ast .Load ()),
558- ),
559- lineno = lineno + 13 ,
560- ),
601+ value = ast .Name (id = 'start' , ctx = ast .Load ()),
602+ attr = 'elapsed_time' ,
603+ ctx = ast .Load ()),
604+ args = [
605+ ast .Name (id = 'end' , ctx = ast .Load ())],
606+ keywords = []),
607+ op = ast .Mult (),
608+ right = ast .Constant (value = 1000000 )), lineno = lineno + 16 ),
561609 ],
562610 handlers = [
563611 ast .ExceptHandler (
564612 type = ast .Name (id = "Exception" , ctx = ast .Load ()),
565613 name = "e" ,
566614 body = [
567615 ast .Assign (
568- targets = [ast .Name (id = "codeflash_duration" , ctx = ast .Store ())],
616+ targets = [
617+ ast .Name (id = 'codeflash_duration' , ctx = ast .Store ())],
569618 value = ast .BinOp (
570619 left = ast .Call (
571620 func = ast .Attribute (
572- value = ast .Name (id = "time" , ctx = ast .Load ()),
573- attr = "perf_counter_ns" ,
574- ctx = ast .Load (),
575- ),
576- args = [],
577- keywords = [],
578- ),
579- op = ast .Sub (),
580- right = ast .Name (id = "counter" , ctx = ast .Load ()),
581- ),
582- lineno = lineno + 15 ,
583- ),
621+ value = ast .Name (id = 'start' , ctx = ast .Load ()),
622+ attr = 'elapsed_time' ,
623+ ctx = ast .Load ()),
624+ args = [
625+ ast .Name (id = 'end' , ctx = ast .Load ())],
626+ keywords = []),
627+ op = ast .Mult (),
628+ right = ast .Constant (value = 1000000 )), lineno = lineno + 18 ),
584629 ast .Assign (
585630 targets = [ast .Name (id = "exception" , ctx = ast .Store ())],
586631 value = ast .Name (id = "e" , ctx = ast .Load ()),
587- lineno = lineno + 13 ,
632+ lineno = lineno + 16 ,
588633 ),
589634 ],
590- lineno = lineno + 14 ,
635+ lineno = lineno + 17 ,
591636 )
592637 ],
593638 orelse = [],
0 commit comments