Skip to content

Commit 995d7bf

Browse files
committed
refactor: optimize MultiDtypeGenerator code style and reduce nesting
- Refactor _modify_weight_meta to use single-pass traversal - Extract _should_keep_weight_float32 helper method - Simplify _ensure_autocast_import to reduce duplicate file writes - Clean up unused variables in _validate_sample - Apply black code formatting
1 parent d88fb1b commit 995d7bf

File tree

2 files changed

+279
-304
lines changed

2 files changed

+279
-304
lines changed

graph_net/torch/multi_dtype_generate.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,59 +16,60 @@
1616

1717
def main():
1818
parser = argparse.ArgumentParser(
19-
description='Generate low-precision samples from existing float32 samples'
19+
description="Generate low-precision samples from existing float32 samples"
2020
)
2121
parser.add_argument(
22-
'--source-sample-path',
22+
"--source-sample-path",
2323
type=str,
2424
required=True,
25-
help='Path to the source sample directory'
25+
help="Path to the source sample directory",
2626
)
2727
parser.add_argument(
28-
'--output-base-path',
28+
"--output-base-path",
2929
type=str,
3030
required=True,
31-
help='Base path for output samples'
31+
help="Base path for output samples",
3232
)
3333
parser.add_argument(
34-
'--dtype-list',
34+
"--dtype-list",
3535
type=str,
36-
nargs='+',
37-
default=['float16', 'bfloat16'],
38-
choices=['float16', 'bfloat16', 'float8'],
39-
help='List of target dtypes to generate (default: float16 bfloat16)'
36+
nargs="+",
37+
default=["float16", "bfloat16"],
38+
choices=["float16", "bfloat16", "float8"],
39+
help="List of target dtypes to generate (default: float16 bfloat16)",
4040
)
4141
parser.add_argument(
42-
'--filter-config',
42+
"--filter-config",
4343
type=str,
4444
default=None,
45-
help='Path to filter configuration file (JSON format)'
45+
help="Path to filter configuration file (JSON format)",
4646
)
47-
47+
4848
args = parser.parse_args()
49-
49+
5050
# Load filter config if provided
5151
filter_config = {}
5252
if args.filter_config:
5353
import json
54-
with open(args.filter_config, 'r') as f:
54+
55+
with open(args.filter_config, "r") as f:
5556
filter_config = json.load(f)
56-
57+
5758
# Create generator
5859
generator = MultiDtypeGenerator(
5960
source_sample_path=args.source_sample_path,
6061
output_base_path=args.output_base_path,
6162
dtype_list=args.dtype_list,
6263
filter_config=filter_config,
6364
)
64-
65+
6566
# Generate samples
6667
print(f"Generating samples from: {args.source_sample_path}")
6768
print(f"Output base path: {args.output_base_path}")
6869
print(f"Target dtypes: {args.dtype_list}")
69-
70+
7071
generated_paths = generator.generate()
71-
72+
7273
if generated_paths:
7374
print(f"\nSuccessfully generated {len(generated_paths)} sample(s):")
7475
for path in generated_paths:
@@ -79,6 +80,5 @@ def main():
7980
return 1
8081

8182

82-
if __name__ == '__main__':
83+
if __name__ == "__main__":
8384
sys.exit(main())
84-

0 commit comments

Comments
 (0)