
Commit 9515e47

feat: Configurable precision (#19)
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
1 parent 20af897

File tree

7 files changed: +16 -3 lines

examples/configs/grpo_math_1B.yaml
examples/configs/grpo_math_8B.yaml
examples/configs/sft.yaml
nemo_reinforcer/models/policy/__init__.py
nemo_reinforcer/models/policy/hf_policy.py
tests/unit/models/generation/test_vllm_generation.py
tests/unit/models/policy/test_hf_ray_policy.py

examples/configs/grpo_math_1B.yaml

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,7 @@
 grpo:
   num_prompts_per_step: 8
   num_generations_per_prompt: 8
-  max_num_steps: 100
+  max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
@@ -30,6 +30,7 @@ policy:
   learning_rate: 5.0e-6
   logprob_batch_size: 4
   max_total_sequence_length: 512
+  precision: "bfloat16"

   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"

examples/configs/grpo_math_8B.yaml

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ policy:
   learning_rate: 5.0e-6
   logprob_batch_size: 2
   max_total_sequence_length: 4096
+  precision: "bfloat16"

   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"

examples/configs/sft.yaml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ policy:
   train_micro_batch_size: 2
   learning_rate: 5.0e-6
   max_total_sequence_length: 1024
+  precision: "float32"

   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"

nemo_reinforcer/models/policy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@ class PolicyConfig(TypedDict):
     learning_rate: float
     logprob_batch_size: int
     generation: GenerationConfig
+    precision: str
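
Since PolicyConfig is a TypedDict, the new precision key is only checked statically (e.g. by mypy); at runtime HfPolicy simply reads cfg["precision"]. A hedged sketch of a conforming dict, with values mirroring the unit tests in this commit and the model name as an illustrative placeholder:

    # Illustrative only: a partial PolicyConfig-style dict including the new key.
    policy_cfg = {
        "model_name": "Qwen/Qwen2.5-0.5B",  # placeholder, not taken from this diff
        "learning_rate": 5.0e-6,
        "logprob_batch_size": 1,
        "precision": "float32",             # new in this commit: "float32" or "bfloat16"
        # "generation": {...},              # GenerationConfig fields omitted here
    }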

nemo_reinforcer/models/policy/hf_policy.py

Lines changed: 8 additions & 2 deletions
@@ -72,17 +72,23 @@ def __init__(
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
+        if self.cfg["precision"] == "float32":
+            dtype = torch.float32
+        elif self.cfg["precision"] == "bfloat16":
+            dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision: {self.cfg['precision']}")

         print(f"[Rank {rank}] Loading model {model_name} on CPU...")
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=torch.float32,  # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
         self.reference_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=torch.float32,  # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )

         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
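
The branch added above is the core of the change: the configured string picks the torch dtype used to load both the policy and the reference model. A self-contained sketch of the same mapping (the helper name and standalone form are mine, not part of the commit):

    import torch

    def dtype_from_precision(precision: str) -> torch.dtype:
        # Mirrors the if/elif/else branch added in HfPolicy.__init__ above.
        if precision == "float32":
            return torch.float32
        elif precision == "bfloat16":
            return torch.bfloat16
        raise ValueError(f"Unknown precision: {precision}")

    assert dtype_from_precision("bfloat16") is torch.bfloat16
    # Any other string ("fp16", "float16", ...) raises ValueError. Note that
    # HfPolicy reads the value via self.cfg["precision"], so a config dict
    # without the key fails earlier with a KeyError.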

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         "logprob_batch_size": 1,
         "max_new_tokens": 16,
         "do_sample": False,
+        "precision": "float32",
     }

     vllm_policy = None
@@ -437,6 +438,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size):
         "logprob_batch_size": 1,
         "max_new_tokens": 16,
         "do_sample": False,
+        "precision": "float32",
     }

     hf_policy = HfPolicy(cluster, hf_config)

tests/unit/models/policy/test_hf_ray_policy.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
     "train_micro_batch_size": 1,
     "learning_rate": 5e-6,
     "logprob_batch_size": 1,
+    "precision": "float32",
     "generation": {
         "backend": "hf",
         "temperature": 1.0,
