|
22 | 22 | {%- endif -%} |
23 | 23 | {{"\n"}} |
24 | 24 | import torch |
25 | | -from torch import device |
| 25 | +from torch import device, inf |
26 | 26 |
|
27 | 27 |
|
28 | 28 | {% macro get_input_tensor_instance(tensor_meta, device) -%} |
|
36 | 36 | {%- if data is not none -%} |
37 | 37 | torch.tensor({{data}}, dtype={{dtype}}).reshape({{shape}}).to(device='{{device}}') |
38 | 38 | {%- elif dtype == "torch.bool" -%} |
39 | | - torch.rand({{shape}}, device={{device}}) > 0.5 |
| 39 | + torch.rand({{shape}}, device='{{device}}') > 0.5 |
40 | 40 | {%- elif dtype in ["torch.int8", "torch.int16", "torch.int32", "torch.int64"] -%} |
41 | 41 | torch.randint({{min_val}}, {{max_val}} + 1, size={{shape}}, dtype={{dtype}}).to(device='{{device}}') |
42 | 42 | {%- elif dtype in ["torch.float16", "torch.bfloat16", "torch.float32", "torch.float64"] -%} |
@@ -118,7 +118,7 @@ def test_main(self): |
118 | 118 | {%- if data is not none -%} |
119 | 119 | paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}') |
120 | 120 | {%- elif dtype == "bool" -%} |
121 | | - paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}') |
| 121 | + paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') |
122 | 122 | {%- elif dtype in ["int8", "int16", "int32", "int64"] -%} |
123 | 123 | paddle.randint(low={{min_val}}, high={{max_val}} + 1, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') |
124 | 124 | {%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%} |
@@ -456,13 +456,8 @@ def __call__(self, rel_model_path: str): |
456 | 456 | self.resumable_handle_sample(rel_model_path) |
457 | 457 |
|
458 | 458 | def sample_handled(self, rel_model_path: str) -> bool: |
459 | | - dst_model_path = Path(self.config["output_dir"]) / rel_model_path |
460 | | - if not dst_model_path.exists(): |
461 | | - return False |
462 | 459 | output_name = self._get_output_name(rel_model_path) |
463 | | - num_model_py_files = len(list(dst_model_path.rglob(output_name))) |
464 | | - assert num_model_py_files <= 1 |
465 | | - return num_model_py_files == 1 |
| 460 | + return self.naive_sample_handled(rel_model_path, search_file_name=output_name) |
466 | 461 |
|
467 | 462 | def _get_output_name(self, rel_model_path: str): |
468 | 463 | return f"{Path(rel_model_path).name}_test.py" |
|
0 commit comments