315 changes: 313 additions & 2 deletions codeflash/code_utils/instrument_existing_tests.py
@@ -671,6 +671,8 @@ def inject_profiling_into_existing_test(
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
mode: TestingMode = TestingMode.BEHAVIOR,
*,
jit_warmup: bool = False,
) -> tuple[bool, str | None]:
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
@@ -704,13 +706,277 @@ def inject_profiling_into_existing_test(
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
additional_functions = [create_wrapper_function(mode)]
additional_functions = [create_wrapper_function(mode, jit_warmup=jit_warmup)]
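    # Prepend the sync helper so it is emitted ahead of the wrapper in the instrumented test module.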
if jit_warmup:
additional_functions.insert(0, create_jit_sync_helper())

tree.body = [*new_imports, *additional_functions, *tree.body]
return True, sort_imports(ast.unparse(tree), float_to_top=True)


def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
def create_jit_sync_helper() -> ast.FunctionDef:
"""Create a helper function that synchronizes JIT-compiled frameworks (PyTorch, TensorFlow, JAX, MLX).

This function generates AST for:
def _codeflash_jit_sync():
try:
import torch
if torch.cuda.is_available():
torch.cuda.synchronize()
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
torch.mps.synchronize()
except ImportError:
pass
try:
import jax
# Block until all JAX computations are complete
jax.effects_barrier()
except ImportError:
pass
try:
import mlx.core as mx
mx.synchronize()
except ImportError:
pass
# Note: TensorFlow in eager mode auto-syncs; Numba JIT is CPU-based and doesn't need sync
"""
lineno = 1

# PyTorch sync block
pytorch_sync = ast.Try(
body=[
ast.Import(names=[ast.alias(name="torch")], lineno=lineno),
# if torch.cuda.is_available(): torch.cuda.synchronize()
ast.If(
test=ast.Call(
func=ast.Attribute(
value=ast.Attribute(value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load()),
attr="is_available",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
body=[
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load()
),
attr="synchronize",
ctx=ast.Load(),
),
args=[],
keywords=[],
)
)
],
orelse=[],
lineno=lineno,
),
# if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize()
ast.If(
test=ast.BoolOp(
op=ast.And(),
values=[
ast.Call(
func=ast.Name(id="hasattr", ctx=ast.Load()),
args=[
ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load()
),
ast.Constant(value="mps"),
],
keywords=[],
),
ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load()
),
attr="mps",
ctx=ast.Load(),
),
attr="is_available",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
],
),
body=[
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()), attr="mps", ctx=ast.Load()
),
attr="synchronize",
ctx=ast.Load(),
),
args=[],
keywords=[],
)
)
],
orelse=[],
lineno=lineno,
),
],
handlers=[
ast.ExceptHandler(
type=ast.Name(id="ImportError", ctx=ast.Load()),
name=None,
body=[ast.Pass(lineno=lineno)],
lineno=lineno,
)
],
orelse=[],
finalbody=[],
lineno=lineno,
)

# JAX sync block - use effects_barrier() to wait for all computations
jax_sync = ast.Try(
body=[
ast.Import(names=[ast.alias(name="jax")], lineno=lineno),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="jax", ctx=ast.Load()), attr="effects_barrier", ctx=ast.Load()
),
args=[],
keywords=[],
)
),
],
handlers=[
ast.ExceptHandler(
type=ast.Name(id="ImportError", ctx=ast.Load()),
name=None,
body=[ast.Pass(lineno=lineno)],
lineno=lineno,
)
],
orelse=[],
finalbody=[],
lineno=lineno,
)

# MLX sync block
mlx_sync = ast.Try(
body=[
ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno),
ast.Expr(
value=ast.Call(
func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()),
args=[],
keywords=[],
)
),
],
handlers=[
ast.ExceptHandler(
type=ast.Name(id="ImportError", ctx=ast.Load()),
name=None,
body=[ast.Pass(lineno=lineno)],
lineno=lineno,
)
],
orelse=[],
finalbody=[],
lineno=lineno,
)

    # TensorFlow sync block - wait for dispatched ops on accelerator devices to finish
tensorflow_sync = ast.Try(
body=[
ast.Import(names=[ast.alias(name="tensorflow", asname="tf")], lineno=lineno),
            # tf.test.experimental.sync_devices() blocks until all dispatched eager ops complete;
            # the hasattr guard and the AttributeError handler below cover older TF versions.
ast.If(
test=ast.Call(
func=ast.Name(id="hasattr", ctx=ast.Load()),
args=[
ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load()),
ast.Constant(value="experimental"),
],
keywords=[],
),
body=[
# Get all physical devices and sync GPUs
ast.For(
target=ast.Name(id="_device", ctx=ast.Store()),
iter=ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load()
),
attr="list_physical_devices",
ctx=ast.Load(),
),
args=[ast.Constant(value="GPU")],
keywords=[],
),
body=[
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="tf", ctx=ast.Load()), attr="test", ctx=ast.Load()
),
attr="experimental",
ctx=ast.Load(),
),
attr="sync_devices",
ctx=ast.Load(),
),
args=[],
keywords=[],
)
)
],
orelse=[],
lineno=lineno,
)
],
orelse=[],
lineno=lineno,
),
],
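        # AttributeError is handled alongside ImportError because tf.test.experimental.sync_devices may be missing on older TensorFlow releases.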
handlers=[
ast.ExceptHandler(
type=ast.Tuple(
elts=[ast.Name(id="ImportError", ctx=ast.Load()), ast.Name(id="AttributeError", ctx=ast.Load())],
ctx=ast.Load(),
),
name=None,
body=[ast.Pass(lineno=lineno)],
lineno=lineno,
)
],
orelse=[],
finalbody=[],
lineno=lineno,
)

return ast.FunctionDef(
name="_codeflash_jit_sync",
args=ast.arguments(
args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[]
),
body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync],
Review comment (Contributor Author): conditionally contain only one of them
decorator_list=[],
returns=None,
lineno=lineno,
)
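
A minimal sketch for inspecting the helper this builder produces (it assumes only the standard ast module; fix_missing_locations fills in the location fields the builder leaves unset):

import ast

helper_module = ast.Module(body=[create_jit_sync_helper()], type_ignores=[])
print(ast.unparse(ast.fix_missing_locations(helper_module)))
# Should print a _codeflash_jit_sync definition matching the docstring above.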


def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_warmup: bool = False) -> ast.FunctionDef:
lineno = 1
wrapper_body: list[ast.stmt] = [
ast.Assign(
Expand Down Expand Up @@ -871,6 +1137,25 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Assign(
targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10
),
        # JIT warmup: call the function once so JIT compilation happens before timing (see the unparsed sketch at the end of this diff)
*(
[
ast.Expr(
value=ast.Call(
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
),
lineno=lineno + 10,
),
ast.Expr(
value=ast.Call(func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]),
lineno=lineno + 10,
),
]
if jit_warmup
else []
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()),
@@ -881,6 +1166,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
ast.Try(
body=[
# Sync before starting timer (ensure previous operations are complete)
*(
[
ast.Expr(
value=ast.Call(
func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]
),
lineno=lineno + 11,
)
]
if jit_warmup
else []
),
ast.Assign(
targets=[ast.Name(id="counter", ctx=ast.Store())],
value=ast.Call(
@@ -901,6 +1199,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
lineno=lineno + 12,
),
# Sync after function call to ensure all GPU/async operations complete before stopping timer
*(
[
ast.Expr(
value=ast.Call(
func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]
),
lineno=lineno + 12,
)
]
if jit_warmup
else []
),
ast.Assign(
targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
value=ast.BinOp(
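
For context, with jit_warmup=True the timing region of the generated wrapper roughly unparses to the sketch below. The timer calls and the result/cleanup handling are elided in the hunks above, so time.perf_counter_ns() and the return_value name are assumptions here:

exception = None
codeflash_wrapped(*args, **kwargs)    # warmup call: triggers JIT compilation outside the timed region
_codeflash_jit_sync()                 # drain any device/async work started by the warmup
gc.disable()
try:
    _codeflash_jit_sync()             # make sure nothing pending leaks into the measurement
    counter = time.perf_counter_ns()  # assumed timer; the actual call is elided in this diff
    return_value = codeflash_wrapped(*args, **kwargs)
    _codeflash_jit_sync()             # wait for the timed call's device work before stopping the clock
    codeflash_duration = time.perf_counter_ns() - counter
    # ... remainder of the wrapper (exception handling, result capture) follows as before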