Skip to content

Commit 7a90a59

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into cosyvoice_part1
2 parents a519c7a + 6053442 commit 7a90a59

File tree

6,135 files changed

+6255202
-33
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

6,135 files changed

+6255202
-33
lines changed

graph_net/paddle/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def convert_meta_classes_to_tensors(file_path):
132132
if isinstance(attrs.get("data"), str):
133133
raise ValueError("Unimplemented")
134134
else:
135-
data_value = paddle.tensor(attrs["data"], dtype=data_type).reshape(
136-
attrs.get("shape"), []
137-
)
135+
data_value = paddle.to_tensor(
136+
attrs.get("data"), dtype=data_type
137+
).reshape(attrs.get("shape"), [])
138138
yield {
139139
"info": {
140140
"shape": attrs.get("shape", []),
@@ -170,4 +170,4 @@ def replay_tensor(info):
170170
if "data" in info and info["data"] is not None:
171171
return info["data"].to(device)
172172

173-
return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
173+
return (paddle.randn(shape).cast(dtype).to(device) * std * 1e-3 + 1e-2).cast(dtype)

graph_net/paddle/validate.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import graph_net
1313
import os
1414
import re
15+
import paddle
1516

1617

1718
def load_class_from_file(file_path: str, class_name: str):
@@ -68,7 +69,15 @@ def main(args):
6869
y = model(**state_dict)[0]
6970

7071
print(np.argmin(y), np.argmax(y))
71-
print(y.shape)
72+
if isinstance(y, paddle.Tensor):
73+
print(y.shape)
74+
elif (isinstance(y, list) or isinstance(y, tuple)) and all(
75+
isinstance(obj, paddle.Tensor) for obj in y
76+
):
77+
# list of paddle.Tensor
78+
print(y[0].shape)
79+
else:
80+
raise ValueError("Illegal return value.")
7281

7382
if not args.no_check_redundancy:
7483
print("Check redundancy ...")

graph_net/torch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,6 @@ def replay_tensor(info):
265265
std = info["info"]["std"]
266266
if "data" in info and info["data"] is not None:
267267
return info["data"].to(device)
268+
if dtype is torch.bool:
269+
return (torch.randn(size=shape) > 0.5).to(dtype).to(device)
268270
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
dda77fde4b537ec1947ab7ca772a7e22148567438f51c8140c4ad22659333583
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"dynamic": false
6+
}

samples/mmseg/DeiT_B/input_meta.py

Whitespace-only changes.

samples/mmseg/DeiT_B/input_tensor_constraints.py

Whitespace-only changes.

samples/mmseg/DeiT_B/model.py

Lines changed: 1759 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)