
Commit 2ddb2a5

[Bug Fix] Handle Big Int Tensor (#158)
* CONTRIBUTE_TUTORIAL_cn.md
* Handle big int tensors by converting to sparse COO
* Update utils
* Update utils
* Update utils
* Update Utils
* Update Utils
* Update Utils
* Update Utils
* Update Utils
1 parent 3b20205 commit 2ddb2a5

File tree: 1 file changed (+52, -29 lines)


graph_net/torch/utils.py

Lines changed: 52 additions & 29 deletions
```diff
@@ -59,7 +59,12 @@ def process_tensor(tensor):
                     "info": info,
                 }
             else:
-                return {"type": "big_int_tensor", "data": tensor.clone(), "info": info}
+                return {
+                    "type": "big_int_tensor_by_range",
+                    "min_val": tensor.min().item(),
+                    "max_val": tensor.max().item(),
+                    "info": info,
+                }
         elif tensor.numel() < 1024:
             return {"type": "small_tensor", "data": tensor.clone(), "info": info}
         else:
```
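To make the new branch concrete, here is a minimal, self-contained sketch of the range-based descriptor. `describe_int_tensor` and the 1024-element cutoff are illustrative stand-ins, not the module's actual API:

```python
import torch


def describe_int_tensor(tensor, info=None, small_limit=1024):
    # Small integer tensors are still cloned; large ones are summarized
    # by their value range only, mirroring "big_int_tensor_by_range" above.
    if tensor.numel() < small_limit:
        return {"type": "small_int_tensor", "data": tensor.clone(), "info": info}
    return {
        "type": "big_int_tensor_by_range",
        "min_val": tensor.min().item(),
        "max_val": tensor.max().item(),
        "info": info,
    }


# A large token-id tensor is reduced to its range instead of being stored.
ids = torch.randint(0, 50000, (4, 2048), dtype=torch.int64)
print(describe_int_tensor(ids)["type"])  # big_int_tensor_by_range
```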
```diff
@@ -73,16 +78,25 @@ def process_tensor(tensor):
     processed_inputs = {"type": "unknown", "value": example_inputs}
 
     def handle_named_tensors(tensor):
-        data_value = None
-        data_type = "random_tensor"
+        info = tensor_info(tensor)
         if tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
             if tensor.numel() < 1024:
-                data_type = "small_int_tensor"
-                data_value = tensor.clone()
+                return {
+                    "info": info,
+                    "data": tensor.clone(),
+                    "type": "small_int_tensor",
+                }
             else:
-                data_type = "big_int_tensor"
-        info = tensor_info(tensor)
-        return {"info": info, "data": data_value, "type": data_type}
+                return {
+                    "info": info,
+                    "min_val": tensor.min().item(),
+                    "max_val": tensor.max().item(),
+                    "type": "big_int_tensor_by_range",
+                }
+        if tensor.numel() < 1024:
+            return {"info": info, "data": tensor.clone(), "type": "small_tensor"}
+        else:
+            return {"info": info, "data": None, "type": "random_tensor"}
 
     processed_weights = {
         key: handle_named_tensors(tensor) for key, tensor in state_dict.items()
```
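For reference, a rough sketch of the four-way classification now performed for named weights; `classify_weight` is a hypothetical helper written only to illustrate the branch outcomes:

```python
import torch


def classify_weight(tensor, small_limit=1024):
    # Integer weights: keep small ones verbatim, reduce large ones to a range.
    # Float weights: keep small ones verbatim, regenerate large ones randomly.
    is_int = tensor.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
    if is_int:
        if tensor.numel() < small_limit:
            return "small_int_tensor"
        return "big_int_tensor_by_range"
    return "small_tensor" if tensor.numel() < small_limit else "random_tensor"


print(classify_weight(torch.arange(10)))                       # small_int_tensor
print(classify_weight(torch.zeros(4096, dtype=torch.int64)))   # big_int_tensor_by_range
print(classify_weight(torch.randn(8)))                         # small_tensor
print(classify_weight(torch.randn(64, 64)))                    # random_tensor
```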
```diff
@@ -114,46 +128,46 @@ def format_data(data):
             return "None"
         elif isinstance(data, torch.Tensor):
             if data.dtype.is_floating_point:
-                return "[{}]".format(", ".join(f"{x:.6f}" for x in data.tolist()))
+                return "[{}]".format(
+                    ", ".join(f"{x:.6f}" for x in data.flatten().tolist())
+                )
             else:
-                return "[{}]".format(", ".join(f"{x}" for x in data.tolist()))
+                return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))
         else:
             return repr(data)
 
     def process_tensor_info(tensor_info, name_prefix="example_input"):
-        data_list = None
-        if "input_" in tensor_info["name"]:
-            if tensor_info["type"] in ["small_tensor", "small_int_tensor"]:
-                data_list = tensor_info["data"].flatten()
-            elif tensor_info["type"] == "big_int_tensor":
-                data_list = f"pt-filename:xxx-key"
-            else:
-                pass
-        else:
-            if tensor_info["type"] == "small_int_tensor":
-                data_list = tensor_info["data"].flatten()
-            if tensor_info["type"] == "big_int_tensor":
-                raise ValueError(
-                    "Unexpected cases: there are weights in big tensor of int type "
-                )
+        tensor_type = tensor_info.get("type")
         info = tensor_info.get("info", {})
         dtype = info.get("dtype", "torch.float")
         shape = info.get("shape", [])
         device = info.get("device", "cpu")
         mean = info.get("mean", 0.0)
         std = info.get("std", 1.0)
         uid = f"{name_prefix}_tensor_meta_{tensor_info.get('name', '')}"
-        return [
+
+        lines = [
             (f"class {uid}:"),
             (f"\tname = \"{tensor_info.get('name', '')}\""),
             (f"\tshape = {shape}"),
             (f'\tdtype = "{dtype}"'),
             (f'\tdevice = "{device}"'),
             (f"\tmean = {get_limited_precision_float_str(mean)}"),
             (f"\tstd = {get_limited_precision_float_str(std)}"),
-            (f"\tdata = {format_data(data_list)}"),
-            (""),
         ]
+        if tensor_type == "big_int_tensor_by_range":
+            lines.append(f"\tmin_val = {tensor_info['min_val']}")
+            lines.append(f"\tmax_val = {tensor_info['max_val']}")
+        elif "data" in tensor_info:
+            data_list = (
+                tensor_info["data"].flatten()
+                if isinstance(tensor_info["data"], torch.Tensor)
+                else tensor_info["data"]
+            )
+            lines.append(f"\tdata = {format_data(data_list)}")
+
+        lines.append("")
+        return lines
 
     input_infos = converted["input_info"]
     if isinstance(input_infos, dict):
```
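To show what the new `min_val`/`max_val` branch contributes to the generated meta class, here is the kind of text block it produces; the tensor name and values below are made up for illustration:

```python
# Illustrative output only: a meta class for a range-summarized int tensor.
# mean/std still come from the stored info dict; min_val/max_val replace the
# literal data list that small tensors get.
meta_text = "\n".join(
    [
        "class example_input_tensor_meta_input_ids:",
        '\tname = "input_ids"',
        "\tshape = [4, 2048]",
        '\tdtype = "torch.int64"',
        '\tdevice = "cuda"',
        "\tmean = 0.0",
        "\tstd = 1.0",
        "\tmin_val = 0",
        "\tmax_val = 49999",
        "",
    ]
)
print(meta_text)
```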
```diff
@@ -202,7 +216,16 @@ def convert_meta_classes_to_tensors(file_path):
         }
         data_value = None
         data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
-        if attrs.get("data") is not None:
+        shape = attrs.get("shape", [])
+
+        if "min_val" in attrs and "max_val" in attrs:
+            min_val = attrs["min_val"]
+            max_val = attrs["max_val"]
+            # torch.randint's upper bound is exclusive, so add 1
+            data_value = torch.randint(
+                min_val, max_val + 1, size=shape, dtype=data_type
+            )
+        elif attrs.get("data") is not None:
             if isinstance(attrs.get("data"), str):
                 raise ValueError("Unimplemented")
             else:
```
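A quick standalone check of the reconstruction step: `torch.randint` samples from the half-open interval `[low, high)`, so the `+ 1` above is what makes `max_val` reachable. The range and shape here are illustrative:

```python
import torch

min_val, max_val = 0, 49999
shape = [4, 2048]
dtype = torch.int64

# torch.randint excludes the upper bound, hence max_val + 1.
data_value = torch.randint(min_val, max_val + 1, size=shape, dtype=dtype)

assert min_val <= data_value.min().item()
assert data_value.max().item() <= max_val
print(data_value.shape, data_value.dtype)  # torch.Size([4, 2048]) torch.int64
```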
