
Commit b98ecdf

Author: Nikhil Khandekar

Add RL training configs, scripts, and environment updates

- Add GRPO training configs for MedCalc-Bench, MedMCQA, MedCaseReasoning, and combined training
- Add plotting scripts for training curves and evaluation results
- Add model evaluation and HuggingFace upload scripts
- Update environment files with improvements for RL training

1 parent abf68cd

23 files changed: +2622 −13 lines

.gitignore

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,17 @@
 outputs/
 
+# Prime-RL (separate repo with trained models)
+prime-rl/
+
+# CSV data exports
+*.csv
+
+# Generated plots
+plots/
+
+# Eval results
+eval_results/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[codz]
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-verified-8b"
name = "medcalc-bench-verified-8b-nothink"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-8B-Instruct-2507"

[orchestrator]
batch_size = 128 # Reduced for 8B model
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
[orchestrator.env.args]
one_shot = false
add_python_tool = false
use_think = false # No thinking tags
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 1024
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10
eval_base_model = true

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = false # No thinking tags
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2
dp = 3
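
With use_think = false and answer_format = "xml", the model is expected to emit its final value directly inside an answer tag instead of producing a <think> block first. Below is a minimal sketch of recovering that value from a completion, assuming the "xml" format wraps the answer in <answer>...</answer> tags; the exact tag name used by the medcalc_bench environment is not shown in this diff.

import re

def extract_answer(completion: str) -> str | None:
    """Return the text inside the first <answer>...</answer> block, if any."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    return match.group(1) if match else None

# With use_think = false the completion should answer directly:
print(extract_answer("Creatinine clearance is <answer>62.5</answer> mL/min"))  # 62.5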

configs/medcalc_bench_rl.toml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# MedCalc-Bench RL Training Configuration
# Clinical calculator reasoning with numeric/date outputs

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-rl-1"
name = "medcalc-bench-rl-1"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 512
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
[orchestrator.env.args]
one_shot = false # Zero-shot prompting
add_python_tool = false # No code execution
use_think = true # Enable <think> reasoning tags
answer_format = "xml" # Use XML format for answers

[orchestrator.sampling]
max_tokens = 2048
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = true
answer_format = "xml"

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy sampling for evaluation
max_tokens = 2048

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0

[trainer.optim]
lr = 1e-5 # LoRA typically benefits from higher LR

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Split model across 2 GPUs for faster per-query latency
dp = 3 # 3 parallel instances
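
The tp and dp comments above pin down the GPU layout: 3 data-parallel replicas, each split 2-way tensor-parallel, must cover exactly the 6 devices in inference_gpu_ids. A quick sanity check of that invariant, reading the config with Python's stdlib tomllib (3.11+):

import tomllib  # stdlib TOML parser, Python 3.11+

with open("configs/medcalc_bench_rl.toml", "rb") as f:
    cfg = tomllib.load(f)

tp = cfg["inference"]["parallel"]["tp"]  # GPUs per model replica
dp = cfg["inference"]["parallel"]["dp"]  # number of independent replicas
n_gpus = len(cfg["inference_gpu_ids"])

# 3 replicas x 2-way tensor parallelism should exactly cover the 6 inference GPUs
assert tp * dp == n_gpus, f"tp*dp={tp * dp} != {n_gpus} inference GPUs"
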
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-verified"
name = "medcalc-bench-verified-2048tok"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 256
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
[orchestrator.env.args]
one_shot = false
add_python_tool = false
use_think = true
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 1024
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10
eval_base_model = true

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = true
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2
dp = 3
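
rollouts_per_example = 16 is what drives GRPO's baseline: the 16 rollouts of each example form a group, and each rollout's reward is normalized against its own group instead of a learned value function. A minimal sketch of the standard group-relative advantage computation; the trainer's exact normalization is not visible in this diff.

import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_examples, rollouts_per_example), e.g. 16 columns here.

    Rollouts are scored relative to their own group's mean and std, so a
    group where every rollout succeeds (or fails) contributes ~zero signal.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0]])  # toy group of 4 rollouts
print(grpo_advantages(rewards))  # positive advantage for the two successes
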
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# MedCalc-Bench RL Training Configuration (with Python Tool)
# Clinical calculator reasoning with code execution

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 8192 # Longer for multi-turn tool use

[wandb]
project = "medcalc-bench-tools"
name = "medcalc-bench-tools-rl"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 256 # Smaller batch for multi-turn (more tokens per example)
rollouts_per_example = 8

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench with tools
[orchestrator.env.args]
one_shot = false
add_python_tool = true # Enable Python tool for calculations
use_think = true
answer_format = "xml"
max_turns = 10 # Allow multiple tool calls
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 2048
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10
eval_base_model = true

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

[orchestrator.eval.env.args]
one_shot = false
add_python_tool = true # Tools enabled for eval too
use_think = true
answer_format = "xml"
max_turns = 10
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0
max_tokens = 2048

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2
dp = 3
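
seq_len = 8192 and max_turns = 10 exist because add_python_tool = true turns each rollout into a short conversation: the model may emit code, see the execution result, and continue. The sketch below shows the shape of that loop; generate, extract_tool_call, and run_python are hypothetical stand-ins for the environment's real sampling, parsing, and sandboxed execution.

import re

def generate(messages: list[dict]) -> str:
    """Stub for policy sampling; the real loop queries the inference server."""
    return "<answer>27.5</answer>"

def extract_tool_call(reply: str) -> str | None:
    """Stub parser: treat a fenced python block as a tool call."""
    m = re.search(r"```python\n(.*?)```", reply, re.DOTALL)
    return m.group(1) if m else None

def run_python(code: str) -> str:
    """Stub executor; the real environment sandboxes this."""
    return "ok"

def rollout(prompt: str, max_turns: int = 10) -> list[dict]:
    """Alternate model turns and tool results, bounded by max_turns."""
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_turns):
        reply = generate(messages)
        messages.append({"role": "assistant", "content": reply})
        code = extract_tool_call(reply)
        if code is None:  # no tool call -> treat reply as the final answer
            break
        messages.append({"role": "tool", "content": run_python(code)})
    return messages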

configs/medcasereasoning_rl.toml

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# MedCaseReasoning RL Training Configuration
# Medical diagnosis from case presentations using LLM-as-a-Judge

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 2048 # Prompts are shorter (max ~1034 tokens)

[wandb]
project = "medcasereasoning-rl"
name = "medcasereasoning-rl"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 512
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcasereasoning"

# Environment arguments - uses LLM-as-a-Judge
[orchestrator.env.args]
judge_model = "gpt-5-nano"
reasoning_effort = "low"
# judge_base_url = "http://localhost:8001/v1" # Uncomment for local judge
# judge_api_key = "your-api-key" # Optional, uses OPENAI_API_KEY by default

[orchestrator.sampling]
max_tokens = 1024
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10
eval_base_model = true

[[orchestrator.eval.env]]
id = "medcasereasoning"
num_examples = 100 # Smaller eval set (LLM judge is slow/expensive)
rollouts_per_example = 1

[orchestrator.eval.env.args]
judge_model = "gpt-5-nano"
reasoning_effort = "low"
eval_split = "test" # Use test set for evaluation

[orchestrator.eval.sampling]
temperature = 0.0
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2
dp = 3
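
Here the reward comes from an external judge rather than string matching: a judge model compares the policy's free-text diagnosis against the reference. A minimal sketch of such a call with the OpenAI Python SDK; the grading prompt and yes/no protocol are illustrative assumptions, not the environment's actual rubric. The commented-out judge_base_url option above would map to OpenAI(base_url=...) for a locally served judge.

import os
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])  # default, per the config comment

PROMPT = (
    "Reference diagnosis: {ref}\n"
    "Proposed diagnosis: {hyp}\n"
    "Do these describe the same condition? Answer only yes or no."
)

def judge(ref: str, hyp: str) -> float:
    resp = client.chat.completions.create(
        model="gpt-5-nano",
        reasoning_effort="low",  # matches the config's reasoning_effort
        messages=[{"role": "user", "content": PROMPT.format(ref=ref, hyp=hyp)}],
    )
    text = resp.choices[0].message.content.strip().lower()
    return 1.0 if text.startswith("yes") else 0.0  # binary reward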
