We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent de838dd commit d0ae0fdCopy full SHA for d0ae0fd
graph_net/torch/utils.py
@@ -228,11 +228,13 @@ def convert_meta_classes_to_tensors(file_path):
228
if isinstance(attrs.get("data"), str):
229
raise ValueError("Unimplemented")
230
else:
231
- data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(shape)
+ data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(
232
+ attrs.get("shape"), []
233
+ )
234
235
yield {
236
"info": {
- "shape": shape,
237
+ "shape": attrs.get("shape", []),
238
"dtype": data_type,
239
"device": attrs.get("device", "cpu"),
240
"mean": attrs.get("mean", 0.0),
0 commit comments