Skip to content

Commit ddc9ecb

Browse files
authored
[Bug Fix] Fix device and inf in generated unittest. (#478)
* Fix device and inf in generated unittest. * Enable try-run in renamer.
1 parent 9297cf6 commit ddc9ecb

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

graph_net/sample_pass/agent_unittest_generator.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
{%- endif -%}
2323
{{"\n"}}
2424
import torch
25-
from torch import device
25+
from torch import device, inf
2626
2727
2828
{% macro get_input_tensor_instance(tensor_meta, device) -%}
@@ -36,7 +36,7 @@
3636
{%- if data is not none -%}
3737
torch.tensor({{data}}, dtype={{dtype}}).reshape({{shape}}).to(device='{{device}}')
3838
{%- elif dtype == "torch.bool" -%}
39-
torch.rand({{shape}}, device={{device}}) > 0.5
39+
torch.rand({{shape}}, device='{{device}}') > 0.5
4040
{%- elif dtype in ["torch.int8", "torch.int16", "torch.int32", "torch.int64"] -%}
4141
torch.randint({{min_val}}, {{max_val}} + 1, size={{shape}}, dtype={{dtype}}).to(device='{{device}}')
4242
{%- elif dtype in ["torch.float16", "torch.bfloat16", "torch.float32", "torch.float64"] -%}
@@ -118,7 +118,7 @@ def test_main(self):
118118
{%- if data is not none -%}
119119
paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}')
120120
{%- elif dtype == "bool" -%}
121-
paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}')
121+
paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}')
122122
{%- elif dtype in ["int8", "int16", "int32", "int64"] -%}
123123
paddle.randint(low={{min_val}}, high={{max_val}} + 1, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}')
124124
{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%}
@@ -456,13 +456,8 @@ def __call__(self, rel_model_path: str):
456456
self.resumable_handle_sample(rel_model_path)
457457

458458
def sample_handled(self, rel_model_path: str) -> bool:
459-
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
460-
if not dst_model_path.exists():
461-
return False
462459
output_name = self._get_output_name(rel_model_path)
463-
num_model_py_files = len(list(dst_model_path.rglob(output_name)))
464-
assert num_model_py_files <= 1
465-
return num_model_py_files == 1
460+
return self.naive_sample_handled(rel_model_path, search_file_name=output_name)
466461

467462
def _get_output_name(self, rel_model_path: str):
468463
return f"{Path(rel_model_path).name}_test.py"

graph_net/torch/graph_variable_renamer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ def __call__(self, rel_model_path):
102102
src_model_path, temp_model_path, rename_map
103103
)
104104
self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
105-
# print("Try to run renamed model...")
106-
# self._try_run(temp_model_path)
105+
self._try_run(temp_model_path)
107106
shutil.copytree(temp_model_path, dst_model_path)
108107

109108
def _try_run(self, model_path):
109+
print(f"[GraphVariableRenamer] Try to run {model_path}")
110110
assert self.model_runnable_predicator(
111111
model_path
112112
), f"{model_path} is not a runnable model"

0 commit comments

Comments
 (0)