Skip to content

Commit d88fb1b

Browse files
committed
feat: implement MultiDtypeGenerator for low-precision sample generation
- Add MultiDtypeGenerator class to generate float16/bfloat16/float8 samples from float32 samples - Automatically convert input and weight tensor dtypes - Keep batch_norm parameters (running_mean, running_var, scale, bias) as float32 - Add torch.autocast context manager to model.py forward method - Update graph_net.json metadata with dtype information - Add sample validation mechanism - Add MultiDtypeFilter for filtering invalid graphs - Add command-line tool for batch generation - Support test_compiler evaluation and Agent code generation workflows
1 parent 47424a2 commit d88fb1b

File tree

2 files changed

+573
-0
lines changed

2 files changed

+573
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Command-line tool for MultiDtypeGenerator.
3+
4+
Usage:
5+
python -m graph_net.torch.multi_dtype_generate \
6+
--source-sample-path /path/to/source/sample \
7+
--output-base-path /path/to/output \
8+
--dtype-list float16 bfloat16
9+
"""
10+
11+
import argparse
12+
import sys
13+
from pathlib import Path
14+
from graph_net.torch.multi_dtype_generator import MultiDtypeGenerator
15+
16+
17+
def main():
18+
parser = argparse.ArgumentParser(
19+
description='Generate low-precision samples from existing float32 samples'
20+
)
21+
parser.add_argument(
22+
'--source-sample-path',
23+
type=str,
24+
required=True,
25+
help='Path to the source sample directory'
26+
)
27+
parser.add_argument(
28+
'--output-base-path',
29+
type=str,
30+
required=True,
31+
help='Base path for output samples'
32+
)
33+
parser.add_argument(
34+
'--dtype-list',
35+
type=str,
36+
nargs='+',
37+
default=['float16', 'bfloat16'],
38+
choices=['float16', 'bfloat16', 'float8'],
39+
help='List of target dtypes to generate (default: float16 bfloat16)'
40+
)
41+
parser.add_argument(
42+
'--filter-config',
43+
type=str,
44+
default=None,
45+
help='Path to filter configuration file (JSON format)'
46+
)
47+
48+
args = parser.parse_args()
49+
50+
# Load filter config if provided
51+
filter_config = {}
52+
if args.filter_config:
53+
import json
54+
with open(args.filter_config, 'r') as f:
55+
filter_config = json.load(f)
56+
57+
# Create generator
58+
generator = MultiDtypeGenerator(
59+
source_sample_path=args.source_sample_path,
60+
output_base_path=args.output_base_path,
61+
dtype_list=args.dtype_list,
62+
filter_config=filter_config,
63+
)
64+
65+
# Generate samples
66+
print(f"Generating samples from: {args.source_sample_path}")
67+
print(f"Output base path: {args.output_base_path}")
68+
print(f"Target dtypes: {args.dtype_list}")
69+
70+
generated_paths = generator.generate()
71+
72+
if generated_paths:
73+
print(f"\nSuccessfully generated {len(generated_paths)} sample(s):")
74+
for path in generated_paths:
75+
print(f" - {path}")
76+
return 0
77+
else:
78+
print("\nNo samples were generated successfully.")
79+
return 1
80+
81+
82+
if __name__ == '__main__':
83+
sys.exit(main())
84+

0 commit comments

Comments
 (0)