Skip to content

Commit 87301a8

Browse files
author
Ubuntu
committed
cuda synchronize for existing test instrumentation
1 parent dfbad90 commit 87301a8

File tree

3 files changed

+93
-44
lines changed

3 files changed

+93
-44
lines changed
Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from typing import Union
2+
13
import torch
24

3-
def sorter_cuda(arr: torch.Tensor)->torch.Tensor:
4-
arr = arr.cuda()
5+
def sorter_cuda(arr: list[float])->list[float]:
6+
arr1 = torch.randperm(10).cuda()
57
print("codeflash stdout: Sorting list")
6-
for i in range(arr.shape[0]):
7-
for j in range(arr.shape[0] - 1):
8-
if arr[j] > arr[j + 1]:
9-
temp = arr[j]
10-
arr[j] = arr[j + 1]
11-
arr[j + 1] = temp
8+
for i in range(arr1.shape[0]):
9+
for j in range(arr1.shape[0] - 1):
10+
if arr1[j] > arr1[j + 1]:
11+
temp = arr1[j]
12+
arr1[j] = arr1[j + 1]
13+
arr1[j + 1] = temp
1214
print(f"result: {arr}")
13-
return arr.cpu()
15+
arr1 = arr1.cpu()
16+
arr.sort()
17+
return arr

code_to_optimize/tests/pytest/test_bb_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from code_to_optimize.bubble_sort_cuda import sorter_cuda
22

33

4-
def test_sort():
4+
def test_sorter_cuda():
55
input = [5, 4, 3, 2, 1, 0]
66
output = sorter_cuda(input)
77
assert output == [0, 1, 2, 3, 4, 5]

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def inject_profiling_into_existing_test(
347347
ast.Import(names=[ast.alias(name="time")]),
348348
ast.Import(names=[ast.alias(name="gc")]),
349349
ast.Import(names=[ast.alias(name="os")]),
350+
ast.Import(names=[ast.alias(name="torch")])
350351
]
351352
if mode == TestingMode.BEHAVIOR:
352353
new_imports.extend(
@@ -524,70 +525,114 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
524525
ast.Try(
525526
body=[
526527
ast.Assign(
527-
targets=[ast.Name(id="counter", ctx=ast.Store())],
528+
targets=[
529+
ast.Name(id='start', ctx=ast.Store())],
528530
value=ast.Call(
529531
func=ast.Attribute(
530-
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
531-
),
532+
value=ast.Attribute(
533+
value=ast.Name(id='torch', ctx=ast.Load()),
534+
attr='cuda',
535+
ctx=ast.Load()),
536+
attr='Event',
537+
ctx=ast.Load()),
532538
args=[],
533-
keywords=[],
534-
),
535-
lineno=lineno + 11,
536-
),
539+
keywords=[
540+
ast.keyword(
541+
arg='enable_timing',
542+
value=ast.Constant(value=True))]), lineno=lineno + 11),
543+
ast.Assign(
544+
targets=[
545+
ast.Name(id='end', ctx=ast.Store())],
546+
value=ast.Call(
547+
func=ast.Attribute(
548+
value=ast.Attribute(
549+
value=ast.Name(id='torch', ctx=ast.Load()),
550+
attr='cuda',
551+
ctx=ast.Load()),
552+
attr='Event',
553+
ctx=ast.Load()),
554+
args=[],
555+
keywords=[
556+
ast.keyword(
557+
arg='enable_timing',
558+
value=ast.Constant(value=True))]), lineno=lineno + 12),
559+
ast.Expr(
560+
value=ast.Call(
561+
func=ast.Attribute(
562+
value=ast.Name(id='start', ctx=ast.Load()),
563+
attr='record',
564+
ctx=ast.Load()),
565+
args=[],
566+
keywords=[]), lineno=lineno+13),
537567
ast.Assign(
538568
targets=[ast.Name(id="return_value", ctx=ast.Store())],
539569
value=ast.Call(
540570
func=ast.Name(id="wrapped", ctx=ast.Load()),
541571
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
542572
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
543573
),
544-
lineno=lineno + 12,
574+
lineno=lineno + 13,
545575
),
576+
ast.Expr(
577+
value=ast.Call(
578+
func=ast.Attribute(
579+
value=ast.Name(id='end', ctx=ast.Load()),
580+
attr='record',
581+
ctx=ast.Load()),
582+
args=[],
583+
keywords=[]), lineno=lineno + 14),
584+
ast.Expr(
585+
value=ast.Call(
586+
func=ast.Attribute(
587+
value=ast.Attribute(
588+
value=ast.Name(id='torch', ctx=ast.Load()),
589+
attr='cuda',
590+
ctx=ast.Load()),
591+
attr='synchronize',
592+
ctx=ast.Load()),
593+
args=[],
594+
keywords=[]),lineno=lineno + 15),
546595
ast.Assign(
547-
targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
596+
targets=[
597+
ast.Name(id='codeflash_duration', ctx=ast.Store())],
548598
value=ast.BinOp(
549599
left=ast.Call(
550600
func=ast.Attribute(
551-
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
552-
),
553-
args=[],
554-
keywords=[],
555-
),
556-
op=ast.Sub(),
557-
right=ast.Name(id="counter", ctx=ast.Load()),
558-
),
559-
lineno=lineno + 13,
560-
),
601+
value=ast.Name(id='start', ctx=ast.Load()),
602+
attr='elapsed_time',
603+
ctx=ast.Load()),
604+
args=[
605+
ast.Name(id='end', ctx=ast.Load())],
606+
keywords=[]),
607+
op=ast.Mult(),
608+
right=ast.Constant(value=1000000)), lineno = lineno + 16),
561609
],
562610
handlers=[
563611
ast.ExceptHandler(
564612
type=ast.Name(id="Exception", ctx=ast.Load()),
565613
name="e",
566614
body=[
567615
ast.Assign(
568-
targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
616+
targets=[
617+
ast.Name(id='codeflash_duration', ctx=ast.Store())],
569618
value=ast.BinOp(
570619
left=ast.Call(
571620
func=ast.Attribute(
572-
value=ast.Name(id="time", ctx=ast.Load()),
573-
attr="perf_counter_ns",
574-
ctx=ast.Load(),
575-
),
576-
args=[],
577-
keywords=[],
578-
),
579-
op=ast.Sub(),
580-
right=ast.Name(id="counter", ctx=ast.Load()),
581-
),
582-
lineno=lineno + 15,
583-
),
621+
value=ast.Name(id='start', ctx=ast.Load()),
622+
attr='elapsed_time',
623+
ctx=ast.Load()),
624+
args=[
625+
ast.Name(id='end', ctx=ast.Load())],
626+
keywords=[]),
627+
op=ast.Mult(),
628+
right=ast.Constant(value=1000000)), lineno=lineno + 18),
584629
ast.Assign(
585630
targets=[ast.Name(id="exception", ctx=ast.Store())],
586631
value=ast.Name(id="e", ctx=ast.Load()),
587-
lineno=lineno + 13,
632+
lineno=lineno + 16,
588633
),
589634
],
590-
lineno=lineno + 14,
635+
lineno=lineno + 17,
591636
)
592637
],
593638
orelse=[],

0 commit comments

Comments
 (0)