Skip to content

Commit eb6948b

Browse files
committed
Refine the replay_tensor according to the new dumpped tensor meta.
1 parent 6f7b3a7 commit eb6948b

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def get_input_spec(args):
7676
name = v["name"]
7777
dtype = v["info"]["dtype"]
7878
shape = v["info"]["shape"]
79-
# print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
8079
input_spec[i] = paddle.static.InputSpec(shape, dtype)
8180
return input_spec
8281

@@ -95,9 +94,6 @@ def test_single_model(args):
9594
synchronizer_func = get_synchronizer_func(args)
9695
input_dict = get_input_dict(args)
9796
model_dy = get_model(args)
98-
input_spec = get_input_spec(args)
99-
build_strategy = paddle.static.BuildStrategy()
100-
build_strategy.build_cinn_pass = False
10197

10298
# eager
10399
print("-- Run with eager mode")
@@ -110,6 +106,7 @@ def test_single_model(args):
110106

111107
# compiled
112108
print("-- Run with compiled mode")
109+
input_spec = get_input_spec(args)
113110
build_strategy = paddle.static.BuildStrategy()
114111
# build_strategy.build_cinn_pass = True
115112
compiled_model = paddle.jit.to_static(

graph_net/paddle/utils.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import ast
1010
import paddle
1111

12+
kLiteralTensorSize = 64
13+
1214

1315
def get_limited_precision_float_str(value):
1416
if not isinstance(value, float):
@@ -35,15 +37,15 @@ def process_tensor(tensor):
3537

3638
info = tensor_info(tensor)
3739
if tensor.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]:
38-
if tensor.numel() < 1024:
40+
if tensor.numel() < kLiteralTensorSize:
3941
return {
4042
"type": "small_int_tensor",
4143
"data": tensor.clone(),
4244
"info": info,
4345
}
4446
else:
4547
return {"type": "big_int_tensor", "data": tensor.clone(), "info": info}
46-
elif tensor.numel() < 1024:
48+
elif tensor.numel() < kLiteralTensorSize:
4749
return {"type": "small_tensor", "data": tensor.clone(), "info": info}
4850
else:
4951
return {"type": "random_tensor", "info": info}
@@ -141,10 +143,10 @@ def convert_meta_classes_to_tensors(file_path):
141143
"shape": attrs.get("shape", []),
142144
"dtype": data_type,
143145
"device": attrs.get("device", "gpu"),
144-
"mean": attrs.get("mean", 0.0),
145-
"std": attrs.get("std", 1.0),
146-
"low": attrs.get("low", 0),
147-
"high": attrs.get("high", 2),
146+
"mean": 0.0 if attrs.get("mean", None) is None else attrs.get("mean"),
147+
"std": 1.0 if attrs.get("std", None) is None else attrs.get("std"),
148+
"min_val": attrs.get("min_val", 0),
149+
"max_val": attrs.get("max_val", 2),
148150
},
149151
"data": data_value,
150152
"name": attrs.get("name"),
@@ -173,17 +175,18 @@ def replay_tensor(info):
173175
device = info["info"]["device"]
174176
dtype = info["info"]["dtype"]
175177
shape = info["info"]["shape"]
176-
min_value = info["info"]["low"] if "low" in info["info"] else 0
177-
max_value = info["info"]["high"] if "high" in info["info"] else 0.5
178+
179+
mean = info["info"]["mean"]
180+
std = info["info"]["std"]
181+
min_val = info["info"]["min_val"]
182+
max_val = info["info"]["max_val"]
178183
if None in shape:
179184
shape = list(map(lambda i: i if i is not None else 1, shape))
180185
if "data" in info and info["data"] is not None:
181186
return paddle.reshape(info["data"], shape).to(dtype).to(device)
182187
elif dtype == paddle.int32 or dtype == paddle.int64:
183188
return paddle.cast(
184-
paddle.randint(
185-
low=min_value, high=max_value + 1, shape=shape, dtype="int64"
186-
),
189+
paddle.randint(low=min_val, high=max_val + 1, shape=shape, dtype="int64"),
187190
dtype,
188191
).to(device)
189192
elif dtype == paddle.bool:
@@ -194,7 +197,9 @@ def replay_tensor(info):
194197
else:
195198
std = info["info"]["std"]
196199
return (
197-
paddle.uniform(shape, dtype="float32", min=min_value, max=max_value)
200+
paddle.clip(
201+
paddle.normal(shape=shape, mean=mean, std=std), min=min_val, max=max_val
202+
)
198203
.to(dtype)
199204
.to(device)
200205
)

graph_net/paddle/validate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def main(args):
6868

6969
y = model(**state_dict)
7070

71-
# print(np.argmin(y), np.argmax(y))
7271
if isinstance(y, paddle.Tensor):
7372
print(y.shape)
7473
elif isinstance(y, list) or isinstance(y, tuple):

0 commit comments

Comments
 (0)