
Commit dfc3ba0

PaliC, msaroufim, and bertmaher authored
Add tests for serialization and deserialization (#49)
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Bert Maher <[email protected]>
1 parent 2037e53 commit dfc3ba0

File tree

3 files changed: +481 -55 lines changed


BackendBench/torchbench_suite.py

Lines changed: 4 additions & 52 deletions
@@ -2,68 +2,20 @@
 Load aten inputs from serialized txt files.
 """
 
-import math
 import re
 import tempfile
 from collections import defaultdict
 from pathlib import Path
 
 import requests
 import torch
-from torch.testing import make_tensor
+from BackendBench.utils import deserialize_args
 
 # the schema for this dataset is the one defined in tritonbench traces.
 # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt
 DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt"
 
 
-dtype_abbrs = {
-    torch.bfloat16: "bf16",
-    torch.float64: "f64",
-    torch.float32: "f32",
-    torch.float16: "f16",
-    torch.complex32: "c32",
-    torch.complex64: "c64",
-    torch.complex128: "c128",
-    torch.int8: "i8",
-    torch.int16: "i16",
-    torch.int32: "i32",
-    torch.int64: "i64",
-    torch.bool: "b8",
-    torch.uint8: "u8",
-}
-
-dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
-
-_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
-
-
-def _deserialize_tensor(size, dtype, stride=None, device="cuda"):
-    kwargs = {}
-    if dtype in _FLOATING_TYPES:
-        kwargs.update({"low": 0, "high": 1})
-    if stride is not None:
-        extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride))
-        data = make_tensor(extent, dtype=dtype, device=device, **kwargs)
-        return data.as_strided(size, stride)
-    return make_tensor(size, dtype=dtype, device=device, **kwargs)
-
-
-def _deserialize_args(inps):
-    inps = inps.strip().strip("'")
-    global_vals = {
-        "T": _deserialize_tensor,
-        "th": torch,
-        "inf": math.inf,
-        "torch": torch,
-        **dtype_abbrs_parsing,
-    }
-    # f strings introduce quotations we dont want
-    for key in dtype_abbrs_parsing:
-        inps = inps.replace(f"'{key}'", key)
-    return eval(inps.strip().strip("'").strip('"'), global_vals)
-
-
 class TorchBenchTest:
     def __init__(self, *args, **kwargs):
         self.args = args
@@ -89,7 +41,7 @@ def __init__(self, op, inputs, topn):
     def tests(self):
         inputs_and_sizes = []
         for inp in self.inputs:
-            args, kwargs = _deserialize_args(inp)
+            args, kwargs = deserialize_args(inp)
             size = _args_size(args) + _args_size(list(kwargs.values()))
             inputs_and_sizes.append((size, inp))
         ret = [x[1] for x in sorted(inputs_and_sizes, reverse=True)]
@@ -98,13 +50,13 @@ def tests(self):
     @property
     def correctness_tests(self):
         for inp in self.tests():
-            args, kwargs = _deserialize_args(inp)
+            args, kwargs = deserialize_args(inp)
             yield TorchBenchTest(*args, **kwargs)
 
     @property
     def performance_tests(self):
         for inp in self.tests():
-            args, kwargs = _deserialize_args(inp)
+            args, kwargs = deserialize_args(inp)
            yield TorchBenchTest(*args, **kwargs)
11062

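For reference, the relocated helper can be exercised directly. The snippet below is a minimal sketch, not part of this commit: it feeds deserialize_args a made-up trace entry in the tritonbench format and inspects the reconstructed arguments; the string, shapes, and kwarg are invented for illustration.

from BackendBench.utils import deserialize_args

# Hypothetical trace entry: one 4x8 float32 tensor argument and one scalar kwarg.
example_line = "((T([4, 8], f32),), {'alpha': 2})"
args, kwargs = deserialize_args(example_line)

print(args[0].shape, args[0].dtype)  # torch.Size([4, 8]) torch.float32
print(kwargs)                        # {'alpha': 2}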
BackendBench/utils.py

Lines changed: 102 additions & 0 deletions
@@ -2,6 +2,29 @@
 import inspect
 import re
 import textwrap
+import math
+import torch
+from torch.testing import make_tensor
+
+dtype_abbrs = {
+    torch.bfloat16: "bf16",
+    torch.float64: "f64",
+    torch.float32: "f32",
+    torch.float16: "f16",
+    torch.complex32: "c32",
+    torch.complex64: "c64",
+    torch.complex128: "c128",
+    torch.int8: "i8",
+    torch.int16: "i16",
+    torch.int32: "i32",
+    torch.int64: "i64",
+    torch.bool: "b8",
+    torch.uint8: "u8",
+}
+
+dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
+
+_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
 
 
 def uses_cuda_stream(func) -> bool:
@@ -51,3 +74,82 @@ def visit_Call(self, node):
     finder = StreamCreationFinder()
     finder.visit(tree)
     return finder.found
+
+
+def _deserialize_tensor(size, dtype, stride=None, device="cuda"):
+    kwargs = {}
+    if dtype in _FLOATING_TYPES:
+        kwargs.update({"low": 0, "high": 1})
+
+    # Fall back to CPU if CUDA is not available
+    if device == "cuda" and not torch.cuda.is_available():
+        device = "cpu"
+
+    if stride is not None:
+        extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride))
+        data = make_tensor(extent, dtype=dtype, device=device, **kwargs)
+        return data.as_strided(size, stride)
+    return make_tensor(size, dtype=dtype, device=device, **kwargs)
+
+
+def _serialize_tensor(tensor):
+    """Helper function to serialize a tensor to string format"""
+    shape = list(tensor.shape)
+    dtype = dtype_abbrs[tensor.dtype]
+    stride = tensor.stride() if not tensor.is_contiguous() else None
+
+    if stride:
+        return f"T({shape}, {dtype}, {list(stride)})"
+    else:
+        return f"T({shape}, {dtype})"
+
+
+def _serialize_value(value):
+    """Helper function to serialize any value (tensor, list, primitive)"""
+    if isinstance(value, torch.Tensor):
+        return _serialize_tensor(value)
+    elif isinstance(value, list):
+        list_parts = [_serialize_value(item) for item in value]
+        return f"[{', '.join(list_parts)}]"
+    else:
+        return repr(value)
+
+
+def serialize_args(args, kwargs) -> str:
+    """Convert args and kwargs back to the BackendBench string format
+
+    Args:
+        args: List of arguments (can contain tensors, lists, primitives)
+        kwargs: Dict of keyword arguments
+
+    Returns:
+        Serialized string in format: (arg1, arg2, ..., key1=val1, key2=val2, ...)
+    """
+    if args is None or kwargs is None:
+        return "None"
+
+    # Process positional arguments
+    parts = [_serialize_value(arg) for arg in args]
+
+    # Process keyword arguments
+    kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()]
+
+    # Handle empty args tuple properly
+    args_str = f"({', '.join(parts)},)" if parts else "()"
+
+    return f"({args_str}, {{{', '.join(kwargs_parts)}}})"
+
+
+def deserialize_args(inps):
+    inps = inps.strip().strip("'")
+    global_vals = {
+        "T": _deserialize_tensor,
+        "th": torch,
+        "inf": math.inf,
+        "torch": torch,
+        **dtype_abbrs_parsing,
+    }
+    # f strings introduce quotations we dont want
+    for key in dtype_abbrs_parsing:
+        inps = inps.replace(f"'{key}'", key)
+    return eval(inps.strip().strip("'").strip('"'), global_vals)
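
The third changed file in this commit (the new test module) is not reproduced on this page. A rough round-trip sketch using only the helpers shown above could look like the following; the tensor shape and the "dim" kwarg are invented for illustration and are not taken from the actual tests.

import torch

from BackendBench.utils import deserialize_args, serialize_args

# Serialize a contiguous float32 tensor plus one keyword argument...
t = torch.rand(2, 3, dtype=torch.float32)
s = serialize_args([t], {"dim": 1})
print(s)  # ((T([2, 3], f32),), {'dim': 1})

# ...then rebuild it. The reconstructed tensor keeps the shape and dtype but holds
# fresh random data (via make_tensor), on CUDA when available and CPU otherwise.
args, kwargs = deserialize_args(s)
assert args[0].shape == t.shape and args[0].dtype == t.dtype
assert kwargs == {"dim": 1}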
