# ---------------------------------------------------------------------------
# CLI flags for the GRPO training/inference launcher.
# NOTE(review): `parser` (an argparse.ArgumentParser) and the -m/--model flag
# are created earlier in the file, outside this hunk — confirm upstream.
# Abbreviations used in the help texts: tbs = train batch size, ibs =
# inference batch size, g = generations per prompt, dp = data parallel.
# ---------------------------------------------------------------------------
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument(
    "-ibs",
    "--inference-batch-size",
    type=int,
    default=64,
    help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
    "-imbs",
    "--inference-microbatch-size",
    type=int,
    default=8,
    help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
    "-tbs",
    "--train-batch-size",
    type=int,
    default=32,
    help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
    "-tMbs",
    "--train-minibatch-size",
    type=int,
    default=1,
    help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
    "-tmbs",
    "--train-microbatch-size",
    type=int,
    default=2,
    # Typo fix: "availiable" -> "available" in the user-visible help string.
    help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
)
# Restrict backends to the two the dispatch below actually supports.
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])

args = parser.parse_args()
|
22 | 53 |
|
|
# Spin up a local Ray cluster scoped to its own namespace for this run.
ray.init(address="local", namespace="ray-example")

# Configuration dicts consumed by the launcher below.
# The inference side only needs the checkpoint path; the training side also
# enables flash-attention and disables the KV cache (not needed for training).
inference_model_config = {"path": args.model}
train_model_config = {
    "path": args.model,
    "use_flash_attention_2": True,
    "use_cache": False,
}
# Sampling hyper-parameters for generation.
generate_config = {"top_k": 50, "top_p": 0.75, "temperature": 0.9}
|
38 | 65 |
|
39 | 66 | if args.backend == "transformers":
|
|
91 | 118 | generate_config=generate_config,
|
92 | 119 | num_generations=args.num_generations,
|
93 | 120 | train_model_config=train_model_config,
|
94 |
| - plugin_config={}, |
| 121 | + # plugin_config={}, # for zero |
| 122 | + plugin_config={ |
| 123 | + "pp_size": 2, |
| 124 | + "tp_size": 1, |
| 125 | + "microbatch_size": args.train_microbatch_size // 2, |
| 126 | + "zero_stage": 0, |
| 127 | + "max_norm": 1.0, |
| 128 | + }, # for pp |
95 | 129 | inference_backend=args.backend,
|
96 | 130 | master_addr="localhost",
|
97 |
| - master_port=29505, |
| 131 | + master_port=29506, |
98 | 132 | core_algo=args.algo,
|
| 133 | + project_name=args.project, |
99 | 134 | )
|
0 commit comments