Skip to content

Commit 0a7bb79

Browse files
committed
Update utils
1 parent 952e52f commit 0a7bb79

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

graph_net/torch/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def format_data(data):
124124
return "None"
125125
elif isinstance(data, torch.Tensor):
126126
if data.dtype.is_floating_point:
127-
return "[{}]".format(", ".join(f"{x:.6f}" for x in data.flatten().tolist()))
127+
return "[{}]".format(
128+
", ".join(f"{x:.6f}" for x in data.flatten().tolist())
129+
)
128130
else:
129131
return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))
130132
else:
@@ -137,7 +139,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
137139
sparse_values = None
138140

139141
if is_sparse:
140-
data_list = None # No dense data for sparse tensors
142+
data_list = None # No dense data for sparse tensors
141143
sparse_indices = tensor_info["data"]["indices"]
142144
sparse_values = tensor_info["data"]["values"]
143145
elif "input_" in tensor_info["name"]:
@@ -146,7 +148,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
146148
else:
147149
pass
148150
else:
149-
if tensor_info["type"] == "small_int_tensor":
151+
if tensor_info["type"] == "small_int_tensor":
150152
data_list = tensor_info["data"].flatten()
151153

152154
info = tensor_info.get("info", {})
@@ -156,7 +158,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
156158
mean = info.get("mean", 0.0)
157159
std = info.get("std", 1.0)
158160
uid = f"{name_prefix}_tensor_meta_{tensor_info.get('name', '')}"
159-
161+
160162
lines = [
161163
(f"class {uid}:"),
162164
(f"\tname = \"{tensor_info.get('name', '')}\""),
@@ -172,11 +174,10 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
172174
lines.append(f"\tvalues = {format_data(sparse_values)}")
173175
else:
174176
lines.append(f"\tdata = {format_data(data_list)}")
175-
177+
176178
lines.append("")
177179
return lines
178180

179-
180181
input_infos = converted["input_info"]
181182
if isinstance(input_infos, dict):
182183
input_infos = [input_infos]
@@ -224,7 +225,7 @@ def convert_meta_classes_to_tensors(file_path):
224225
}
225226
data_value = None
226227
data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
227-
228+
228229
if attrs.get("is_sparse"):
229230
indices_shape = (len(attrs.get("shape")), -1)
230231
indices = torch.tensor(attrs["indices"]).reshape(indices_shape)

0 commit comments

Comments
 (0)