Skip to content

Commit 60bfaac

Browse files
committed
Handle big int tensors by converting to sparse COO
1 parent fc817d0 commit 60bfaac

File tree

1 file changed

+52
-19
lines changed

graph_net/torch/utils.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ def process_tensor(tensor):
5757
"info": info,
5858
}
5959
else:
60-
return {"type": "big_int_tensor", "data": tensor.clone(), "info": info}
60+
sparse_tensor = tensor.to_sparse_coo()
61+
return {
62+
"type": "sparse_int_tensor",
63+
"indices": sparse_tensor.indices().clone(),
64+
"values": sparse_tensor.values().clone(),
65+
"info": info,
66+
}
6167
elif tensor.numel() < 1024:
6268
return {"type": "small_tensor", "data": tensor.clone(), "info": info}
6369
else:
@@ -78,15 +84,20 @@ def handle_named_tensors(tensor):
7884
data_type = "small_int_tensor"
7985
data_value = tensor.clone()
8086
else:
81-
data_type = "big_int_tensor"
87+
data_type = "sparse_int_tensor"
88+
sparse_tensor = tensor.to_sparse_coo()
89+
data_value = {
90+
"indices": sparse_tensor.indices().clone(),
91+
"values": sparse_tensor.values().clone(),
92+
}
93+
8294
info = tensor_info(tensor)
8395
return {"info": info, "data": data_value, "type": data_type}
8496

8597
processed_weights = {
8698
key: handle_named_tensors(tensor) for key, tensor in state_dict.items()
8799
}
88100

89-
# dynamic_shapes = extract_dynamic_shapes(example_inputs)
90101
return {
91102
"input_info": processed_inputs,
92103
"weight_info": processed_weights,
@@ -112,46 +123,59 @@ def format_data(data):
112123
return "None"
113124
elif isinstance(data, torch.Tensor):
114125
if data.dtype.is_floating_point:
115-
return "[{}]".format(", ".join(f"{x:.6f}" for x in data.tolist()))
126+
return "[{}]".format(", ".join(f"{x:.6f}" for x in data.flatten().tolist()))
116127
else:
117-
return "[{}]".format(", ".join(f"{x}" for x in data.tolist()))
128+
return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))
118129
else:
119130
return repr(data)
120131

121132
def process_tensor_info(tensor_info, name_prefix="example_input"):
122133
data_list = None
123-
if "input_" in tensor_info["name"]:
134+
# MODIFICATION: Handle sparse tensor serialization
135+
is_sparse = tensor_info.get("type") == "sparse_int_tensor"
136+
sparse_indices = None
137+
sparse_values = None
138+
139+
if is_sparse:
140+
data_list = None # No dense data for sparse tensors
141+
sparse_indices = tensor_info["data"]["indices"]
142+
sparse_values = tensor_info["data"]["values"]
143+
elif "input_" in tensor_info["name"]:
124144
if tensor_info["type"] in ["small_tensor", "small_int_tensor"]:
125145
data_list = tensor_info["data"].flatten()
126-
elif tensor_info["type"] == "big_int_tensor":
127-
data_list = f"pt-filename:xxx-key"
128146
else:
129147
pass
130148
else:
131-
if tensor_info["type"] == "small_int_tensor":
149+
if tensor_info["type"] == "small_int_tensor":
132150
data_list = tensor_info["data"].flatten()
133-
if tensor_info["type"] == "big_int_tensor":
134-
raise ValueError(
135-
"Unexpected cases: there are weights in big tensor of int type "
136-
)
151+
137152
info = tensor_info.get("info", {})
138153
dtype = info.get("dtype", "torch.float")
139154
shape = info.get("shape", [])
140155
device = info.get("device", "cpu")
141156
mean = info.get("mean", 0.0)
142157
std = info.get("std", 1.0)
143158
uid = f"{name_prefix}_tensor_meta_{tensor_info.get('name', '')}"
144-
return [
159+
160+
lines = [
145161
(f"class {uid}:"),
146162
(f"\tname = \"{tensor_info.get('name', '')}\""),
147163
(f"\tshape = {shape}"),
148164
(f'\tdtype = "{dtype}"'),
149165
(f'\tdevice = "{device}"'),
150166
(f"\tmean = {get_limited_precision_float_str(mean)}"),
151167
(f"\tstd = {get_limited_precision_float_str(std)}"),
152-
(f"\tdata = {format_data(data_list)}"),
153-
(""),
154168
]
169+
if is_sparse:
170+
lines.append(f"\tis_sparse = True")
171+
lines.append(f"\tindices = {format_data(sparse_indices)}")
172+
lines.append(f"\tvalues = {format_data(sparse_values)}")
173+
else:
174+
lines.append(f"\tdata = {format_data(data_list)}")
175+
176+
lines.append("")
177+
return lines
178+
155179

156180
input_infos = converted["input_info"]
157181
if isinstance(input_infos, dict):
@@ -200,13 +224,22 @@ def convert_meta_classes_to_tensors(file_path):
200224
}
201225
data_value = None
202226
data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
203-
if attrs.get("data") is not None:
227+
228+
# MODIFICATION: Reconstruct sparse tensors during loading
229+
if attrs.get("is_sparse"):
230+
indices_shape = (len(attrs.get("shape")), -1)
231+
indices = torch.tensor(attrs["indices"]).reshape(indices_shape)
232+
values = torch.tensor(attrs["values"], dtype=data_type)
233+
shape = attrs.get("shape")
234+
data_value = torch.sparse_coo_tensor(indices, values, shape).to_dense()
235+
elif attrs.get("data") is not None:
204236
if isinstance(attrs.get("data"), str):
205237
raise ValueError("Unimplemented")
206238
else:
207239
data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(
208-
attrs.get("shape"), []
240+
attrs.get("shape", [])
209241
)
242+
210243
yield {
211244
"info": {
212245
"shape": attrs.get("shape", []),
@@ -240,4 +273,4 @@ def replay_tensor(info):
240273
std = info["info"]["std"]
241274
if "data" in info and info["data"] is not None:
242275
return info["data"].to(device)
243-
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
276+
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean

0 commit comments

Comments (0)