|
2 | 2 | import inspect
|
3 | 3 | import re
|
4 | 4 | import textwrap
|
| 5 | +import math |
| 6 | +import torch |
| 7 | +from torch.testing import make_tensor |
| 8 | + |
| 9 | +dtype_abbrs = { |
| 10 | + torch.bfloat16: "bf16", |
| 11 | + torch.float64: "f64", |
| 12 | + torch.float32: "f32", |
| 13 | + torch.float16: "f16", |
| 14 | + torch.complex32: "c32", |
| 15 | + torch.complex64: "c64", |
| 16 | + torch.complex128: "c128", |
| 17 | + torch.int8: "i8", |
| 18 | + torch.int16: "i16", |
| 19 | + torch.int32: "i32", |
| 20 | + torch.int64: "i64", |
| 21 | + torch.bool: "b8", |
| 22 | + torch.uint8: "u8", |
| 23 | +} |
| 24 | + |
| 25 | +dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} |
| 26 | + |
| 27 | +_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] |
5 | 28 |
|
6 | 29 |
|
7 | 30 | def uses_cuda_stream(func) -> bool:
|
@@ -51,3 +74,82 @@ def visit_Call(self, node):
|
51 | 74 | finder = StreamCreationFinder()
|
52 | 75 | finder.visit(tree)
|
53 | 76 | return finder.found
|
| 77 | + |
| 78 | + |
| 79 | +def _deserialize_tensor(size, dtype, stride=None, device="cuda"): |
| 80 | + kwargs = {} |
| 81 | + if dtype in _FLOATING_TYPES: |
| 82 | + kwargs.update({"low": 0, "high": 1}) |
| 83 | + |
| 84 | + # Fall back to CPU if CUDA is not available |
| 85 | + if device == "cuda" and not torch.cuda.is_available(): |
| 86 | + device = "cpu" |
| 87 | + |
| 88 | + if stride is not None: |
| 89 | + extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) |
| 90 | + data = make_tensor(extent, dtype=dtype, device=device, **kwargs) |
| 91 | + return data.as_strided(size, stride) |
| 92 | + return make_tensor(size, dtype=dtype, device=device, **kwargs) |
| 93 | + |
| 94 | + |
| 95 | +def _serialize_tensor(tensor): |
| 96 | + """Helper function to serialize a tensor to string format""" |
| 97 | + shape = list(tensor.shape) |
| 98 | + dtype = dtype_abbrs[tensor.dtype] |
| 99 | + stride = tensor.stride() if not tensor.is_contiguous() else None |
| 100 | + |
| 101 | + if stride: |
| 102 | + return f"T({shape}, {dtype}, {list(stride)})" |
| 103 | + else: |
| 104 | + return f"T({shape}, {dtype})" |
| 105 | + |
| 106 | + |
| 107 | +def _serialize_value(value): |
| 108 | + """Helper function to serialize any value (tensor, list, primitive)""" |
| 109 | + if isinstance(value, torch.Tensor): |
| 110 | + return _serialize_tensor(value) |
| 111 | + elif isinstance(value, list): |
| 112 | + list_parts = [_serialize_value(item) for item in value] |
| 113 | + return f"[{', '.join(list_parts)}]" |
| 114 | + else: |
| 115 | + return repr(value) |
| 116 | + |
| 117 | + |
| 118 | +def serialize_args(args, kwargs) -> str: |
| 119 | + """Convert args and kwargs back to the BackendBench string format |
| 120 | +
|
| 121 | + Args: |
| 122 | + args: List of arguments (can contain tensors, lists, primitives) |
| 123 | + kwargs: Dict of keyword arguments |
| 124 | +
|
| 125 | + Returns: |
| 126 | + Serialized string in format: (arg1, arg2, ..., key1=val1, key2=val2, ...) |
| 127 | + """ |
| 128 | + if args is None or kwargs is None: |
| 129 | + return "None" |
| 130 | + |
| 131 | + # Process positional arguments |
| 132 | + parts = [_serialize_value(arg) for arg in args] |
| 133 | + |
| 134 | + # Process keyword arguments |
| 135 | + kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] |
| 136 | + |
| 137 | + # Handle empty args tuple properly |
| 138 | + args_str = f"({', '.join(parts)},)" if parts else "()" |
| 139 | + |
| 140 | + return f"({args_str}, {{{', '.join(kwargs_parts)}}})" |
| 141 | + |
| 142 | + |
| 143 | +def deserialize_args(inps): |
| 144 | + inps = inps.strip().strip("'") |
| 145 | + global_vals = { |
| 146 | + "T": _deserialize_tensor, |
| 147 | + "th": torch, |
| 148 | + "inf": math.inf, |
| 149 | + "torch": torch, |
| 150 | + **dtype_abbrs_parsing, |
| 151 | + } |
| 152 | + # f strings introduce quotations we dont want |
| 153 | + for key in dtype_abbrs_parsing: |
| 154 | + inps = inps.replace(f"'{key}'", key) |
| 155 | + return eval(inps.strip().strip("'").strip('"'), global_vals) |
0 commit comments