Skip to content

Commit d0ae0fd

Browse files
committed
Update Utils
1 parent de838dd commit d0ae0fd

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

graph_net/torch/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,13 @@ def convert_meta_classes_to_tensors(file_path):
228228
if isinstance(attrs.get("data"), str):
229229
raise ValueError("Unimplemented")
230230
else:
231-
data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(shape)
231+
data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(
232+
attrs.get("shape"), []
233+
)
232234

233235
yield {
234236
"info": {
235-
"shape": shape,
237+
"shape": attrs.get("shape", []),
236238
"dtype": data_type,
237239
"device": attrs.get("device", "cpu"),
238240
"mean": attrs.get("mean", 0.0),

0 commit comments

Comments
 (0)