Skip to content

Commit 2f13a21

Browse files
committed
important fix pad_id
1 parent 9a696a6 commit 2f13a21

File tree

3 files changed

+296
-42
lines changed

3 files changed

+296
-42
lines changed

apps/julia-grpo/llama3_8b_julia.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# Global configuration
55
group_size: 8 # num_generations from unsloth.py
6-
batch_size: 1 # per_device_train_batch_size from unsloth.py
6+
batch_size: 4 # per_device_train_batch_size from unsloth.py
77
max_req_tokens: 2048 # max_prompt_length from unsloth.py
88
max_res_tokens: 1024 # max_completion_length from unsloth.py
99
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"

apps/julia-grpo/main.py

Lines changed: 160 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,43 @@ def response_tensor(self) -> torch.Tensor:
8282
Policy = Generator
8383

8484

85+
# def collate(
86+
# batches: list[Group],
87+
# ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
88+
# """
89+
# Collates a list of batches into a single batch of inputs and targets.
90+
# Each batch is a list of episodes, and each episode is a dict of tensors.
91+
# """
92+
# inputs = []
93+
# targets = []
94+
# for batch in batches:
95+
# request = [e.request_tensor for e in batch]
96+
# request = torch.stack(request) # [b x s]
97+
98+
# response = [e.response_tensor for e in batch]
99+
# response = torch.stack(response) # [b x s]
100+
101+
# ref_logprobs = [e.ref_logprobs for e in batch]
102+
# ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]
103+
104+
# advantages = [e.advantage for e in batch]
105+
# advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
106+
107+
# pad_id = batch[0].pad_id
108+
# mask = torch.ne(response, pad_id)
109+
110+
# input = {"tokens": torch.cat([request, response], dim=1)}
111+
# target = {
112+
# "response": response,
113+
# "ref_logprobs": ref_logprobs,
114+
# "advantages": advantages,
115+
# "padding_mask": mask,
116+
# }
117+
# inputs.append(input)
118+
# targets.append(target)
119+
# return inputs, targets
120+
121+
85122
def collate(
86123
batches: list[Group],
87124
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
@@ -91,21 +128,60 @@ def collate(
91128
"""
92129
inputs = []
93130
targets = []
94-
for batch in batches:
131+
for batch_idx, batch in enumerate(batches):
132+
print(f"[DEBUG] Processing batch {batch_idx}, len={len(batch)}")
133+
95134
request = [e.request_tensor for e in batch]
96135
request = torch.stack(request) # [b x s]
136+
print(f"[DEBUG] request shape: {request.shape}")
97137

98138
response = [e.response_tensor for e in batch]
99139
response = torch.stack(response) # [b x s]
140+
print(f"[DEBUG] response shape: {response.shape}")
100141

101142
ref_logprobs = [e.ref_logprobs for e in batch]
102-
ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]
143+
ref_logprobs = torch.stack(ref_logprobs) # [b x s]
144+
145+
# Only squeeze the first dimension if it exists and is size 1
146+
# This prevents squeezing the sequence dimension
147+
if ref_logprobs.dim() > 2:
148+
ref_logprobs = ref_logprobs.squeeze(0)
149+
print(f"[DEBUG] ref_logprobs shape after stack: {ref_logprobs.shape}")
103150

104151
advantages = [e.advantage for e in batch]
105152
advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
153+
print(f"[DEBUG] advantages shape: {advantages.shape}")
106154

107155
pad_id = batch[0].pad_id
108-
mask = response != pad_id
156+
157+
# Ensure mask is always a 2D tensor [b x s], even for single batch elements
158+
mask = torch.ne(response, pad_id) # Should be [b x s]
159+
print(
160+
f"[DEBUG] mask shape before checks: {mask.shape}, dtype: {mask.dtype}, type: {type(mask)}"
161+
)
162+
163+
# Ensure it's a tensor and preserve shape
164+
if not isinstance(mask, torch.Tensor):
165+
print(
166+
f"[DEBUG] WARNING: mask is not a tensor, converting from {type(mask)}"
167+
)
168+
mask = torch.tensor(mask, dtype=torch.bool)
169+
170+
# Ensure mask is always 2D
171+
if mask.dim() == 0:
172+
print(f"[DEBUG] WARNING: mask is 0D scalar, unsqueezing twice")
173+
mask = mask.unsqueeze(0).unsqueeze(0)
174+
elif mask.dim() == 1:
175+
print(
176+
f"[DEBUG] WARNING: mask is 1D with shape {mask.shape}, unsqueezing to 2D"
177+
)
178+
mask = mask.unsqueeze(0)
179+
180+
print(f"[DEBUG] mask final shape: {mask.shape}")
181+
print(
182+
f"[DEBUG] All shapes - request: {request.shape}, response: {response.shape}, "
183+
f"ref_logprobs: {ref_logprobs.shape}, advantages: {advantages.shape}, mask: {mask.shape}"
184+
)
109185

110186
input = {"tokens": torch.cat([request, response], dim=1)}
111187
target = {
@@ -116,6 +192,7 @@ def collate(
116192
}
117193
inputs.append(input)
118194
targets.append(target)
195+
119196
return inputs, targets
120197

121198

@@ -125,7 +202,7 @@ def simple_grpo_loss(
125202
ref_logprobs: torch.Tensor,
126203
advantages: torch.Tensor,
127204
padding_mask: torch.Tensor,
128-
beta: float = 0.1,
205+
beta: float = 0.005,
129206
) -> torch.Tensor:
130207
logprobs: torch.Tensor = compute_logprobs(logits, response)
131208
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
@@ -140,83 +217,120 @@ def simple_grpo_loss(
140217

141218
@dataclass
142219
class JuliaRewardActor(ForgeActor):
143-
"""Reward actor for Julia code execution using GenericOpenEnvActor."""
220+
"""Reward actor for Julia code execution using GenericOpenEnvActor.
221+
222+
Uses a dense reward structure:
223+
- 0.0: Code failed to execute or tests failed
224+
- reward > 0.0: Reward based on test success rate
225+
- 1.0: All tests passed
226+
"""
144227

145228
julia_env: GenericOpenEnvActor
146229

147230
@endpoint
148-
async def evaluate_response(
149-
self, prompt: str, response: str, target: dict
150-
) -> float:
231+
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
151232
"""
152233
Evaluate Julia code by executing it with test cases.
153234
154235
Args:
155236
prompt: The problem description (not used directly, but available)
156237
response: The Julia code to evaluate
157-
target: Dict containing test cases and expected outputs
238+
target: The Julia test code as a string
158239
159240
Returns:
160241
Reward score based on test case pass rate
161242
"""
243+
reward = 0.0
244+
245+
print("=" * 80)
246+
print("RAW RESPONSE FROM MODEL:")
247+
print("-" * 80)
248+
print(response)
249+
print("-" * 80)
250+
162251
try:
163252
# Extract code from markdown code blocks if present
164253
code = self._extract_code(response)
165254

166-
# Get test cases from target
167-
test_cases = target.get("test_cases", [])
168-
if not test_cases:
169-
record_metric("reward/julia/no_test_cases", 1, Reduce.SUM)
255+
if not code:
256+
print("No Julia code extracted - Reward: 0.0")
257+
print("=" * 80)
258+
record_metric("reward/julia/no_code_extracted", 1, Reduce.SUM)
259+
return 0.0
260+
261+
print("EXTRACTED JULIA CODE:")
262+
print("-" * 80)
263+
print(code)
264+
print("-" * 80)
265+
266+
# Use target as the test code directly
267+
if not target or not isinstance(target, str):
268+
print("No test code provided - Reward: 0.0")
269+
print("=" * 80)
270+
record_metric("reward/julia/no_test_code", 1, Reduce.SUM)
170271
return 0.0
171272

172-
# Execute code with test cases using JuliaAction
273+
# Execute code with test code using JuliaAction
274+
# The test code is the complete Julia test suite
173275
action = JuliaAction(
174-
code=code,
175-
test_cases=test_cases,
276+
core_code=code,
277+
test_code=target,
176278
)
177279

178-
result = await self.julia_env.execute.route(action)
280+
result = await self.julia_env.execute.call_one(action)
179281

180-
# Calculate reward based on test results
282+
# Extract reward from result
283+
reward = result.reward if result.reward is not None else 0.0
181284
obs = result.observation
285+
182286
passed = obs.tests_passed
183-
total = obs.tests_total
287+
failed = obs.tests_failed
288+
total = passed + failed
184289

185-
if total == 0:
186-
reward = 0.0
187-
else:
188-
# Pass rate as reward (0.0 to 1.0)
189-
reward = passed / total
290+
# Log execution details
291+
print("JuliaEnv Execution Result:")
292+
print(f" Reward: {reward:.3f}")
293+
print(f" Tests Passed: {passed}")
294+
print(f" Tests Failed: {failed}")
295+
print(f" Total Tests: {total}")
296+
297+
if obs.stderr:
298+
print(f" Stderr: {obs.stderr[:200]}")
299+
record_metric("reward/julia/has_errors", 1, Reduce.SUM)
300+
301+
if obs.error_message:
302+
print(f" Error Message: {obs.error_message[:200]}")
190303

191304
# Log metrics
192305
record_metric("reward/julia/tests_passed", passed, Reduce.SUM)
306+
record_metric("reward/julia/tests_failed", failed, Reduce.SUM)
193307
record_metric("reward/julia/tests_total", total, Reduce.SUM)
194308
record_metric("reward/julia/pass_rate", reward, Reduce.MEAN)
195309

196-
if obs.stderr:
197-
record_metric("reward/julia/has_errors", 1, Reduce.SUM)
310+
print(f"Final Reward: {reward:.3f}")
311+
print("=" * 80)
198312

199313
return reward
200314

315+
except asyncio.TimeoutError:
316+
print("✗ JuliaEnv request timeout - Reward: 0.0")
317+
print("=" * 80)
318+
record_metric("reward/julia/timeout_errors", 1, Reduce.SUM)
319+
return 0.0
201320
except Exception as e:
202-
print(f"Error evaluating Julia response: {e}")
321+
print(f"✗ Unexpected error: {e} - Reward: 0.0")
322+
print("=" * 80)
203323
record_metric("reward/julia/evaluation_errors", 1, Reduce.SUM)
204324
return 0.0
205325

206326
def _extract_code(self, response: str) -> str:
207-
"""Extract Julia code from markdown code blocks."""
208-
# Remove markdown code fences if present
209-
if "```julia" in response:
210-
start = response.find("```julia") + len("```julia")
211-
end = response.find("```", start)
212-
if end != -1:
213-
return response[start:end].strip()
214-
elif "```" in response:
215-
start = response.find("```") + len("```")
216-
end = response.find("```", start)
217-
if end != -1:
218-
return response[start:end].strip()
219-
return response.strip()
327+
"""Extract Julia code from markdown code blocks using regex."""
328+
import re
329+
330+
# Remove markdown code blocks with regex (more robust)
331+
text = re.sub(r"^```julia\s*\n?", "", response, flags=re.IGNORECASE)
332+
text = re.sub(r"\n?```\s*$", "", text)
333+
return text.strip()
220334

221335

222336
@dataclass
@@ -349,7 +463,12 @@ async def sample(self) -> dict[str, str] | None:
349463

350464
@endpoint
351465
async def pad_token(self):
352-
return self._tokenizer.pad_token_id
466+
# Use pad_token_id if available, otherwise use eos_token_id
467+
# Llama models don't have a pad token by default
468+
if self._tokenizer.pad_token_id is not None:
469+
return self._tokenizer.pad_token_id
470+
else:
471+
return self._tokenizer.eos_token_id
353472

354473

355474
async def drop_weights(version: int):

0 commit comments

Comments
 (0)