@@ -82,6 +82,43 @@ def response_tensor(self) -> torch.Tensor:
# Alias: `Policy` is bound to the same object as `Generator`.
Policy = Generator
8383
8484
def collate(
    batches: list[Group],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """
    Collate groups of episodes into per-group model inputs and loss targets.

    Each element of ``batches`` is a sequence of episodes. Every episode is
    read through its ``request_tensor``, ``response_tensor``, ``ref_logprobs``,
    ``advantage`` and ``pad_id`` attributes.

    Args:
        batches: Groups of episodes; one input/target pair is produced per
            group. Every group must be non-empty (``torch.stack`` raises on
            an empty list).

    Returns:
        ``(inputs, targets)`` where ``inputs[i]`` is ``{"tokens": ...}`` with
        request and response concatenated along dim 1, and ``targets[i]``
        holds ``"response"``, ``"ref_logprobs"``, ``"advantages"`` and
        ``"padding_mask"`` for the same group.
    """
    inputs = []
    targets = []
    for batch in batches:
        request = torch.stack([e.request_tensor for e in batch])  # [b x s]
        response = torch.stack([e.response_tensor for e in batch])  # [b x s]

        # Per-episode ref_logprobs may carry a leading singleton axis
        # (stack -> [b x 1 x s]). A bare squeeze() also drops the batch axis
        # when b == 1, and squeeze(0) leaves the tensor 3-D when b > 1, so
        # flatten everything after the batch axis instead: always [b x s].
        ref_logprobs = torch.stack([e.ref_logprobs for e in batch])
        ref_logprobs = ref_logprobs.reshape(len(batch), -1)

        advantages = torch.tensor([e.advantage for e in batch]).unsqueeze(-1)  # [b x 1]

        # The pad id is taken from the first episode of the group.
        pad_id = batch[0].pad_id
        # Elementwise comparison keeps the shape of `response`, so the mask
        # is [b x s] for any batch size; no extra reshaping is required.
        mask = torch.ne(response, pad_id)

        input = {"tokens": torch.cat([request, response], dim=1)}
        target = {
            "response": response,
            "ref_logprobs": ref_logprobs,
            "advantages": advantages,
            "padding_mask": mask,
        }
        inputs.append(input)
        targets.append(target)

    return inputs, targets
120197
121198
@@ -125,7 +202,7 @@ def simple_grpo_loss(
125202 ref_logprobs : torch .Tensor ,
126203 advantages : torch .Tensor ,
127204 padding_mask : torch .Tensor ,
128- beta : float = 0.1 ,
205+ beta : float = 0.005 ,
129206) -> torch .Tensor :
130207 logprobs : torch .Tensor = compute_logprobs (logits , response )
131208 kl = torch .exp (ref_logprobs - logprobs ) - (ref_logprobs - logprobs ) - 1
@@ -140,83 +217,120 @@ def simple_grpo_loss(
140217
@dataclass
class JuliaRewardActor(ForgeActor):
    """Reward actor for Julia code execution using GenericOpenEnvActor.

    The reward is taken from the environment's result (``result.reward``).
    As documented by the original author, the intended structure is dense:
    - 0.0: Code failed to execute or tests failed
    - reward > 0.0: Reward based on test success rate
    - 1.0: All tests passed

    This actor itself returns 0.0 when no code or no test code is available,
    when the environment reports no reward, or when evaluation raises.
    """

    julia_env: GenericOpenEnvActor

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
        """
        Evaluate Julia code by executing it with test cases.

        Args:
            prompt: The problem description (not used directly, but available)
            response: The raw model output; Julia code is extracted from it
            target: The Julia test code as a string

        Returns:
            The reward reported by the Julia environment, or 0.0 when there
            is nothing to run or the evaluation fails. Exceptions are not
            propagated: they are counted in metrics and mapped to 0.0.
        """
        print("=" * 80)
        print("RAW RESPONSE FROM MODEL:")
        print("-" * 80)
        print(response)
        print("-" * 80)

        try:
            # Extract code from markdown code blocks if present
            code = self._extract_code(response)

            if not code:
                print("No Julia code extracted - Reward: 0.0")
                print("=" * 80)
                record_metric("reward/julia/no_code_extracted", 1, Reduce.SUM)
                return 0.0

            print("EXTRACTED JULIA CODE:")
            print("-" * 80)
            print(code)
            print("-" * 80)

            # The target is used as the test code directly; guard against
            # missing or non-string targets coming from the dataset.
            if not target or not isinstance(target, str):
                print("No test code provided - Reward: 0.0")
                print("=" * 80)
                record_metric("reward/julia/no_test_code", 1, Reduce.SUM)
                return 0.0

            # Execute the candidate code together with the test code.
            action = JuliaAction(
                core_code=code,
                test_code=target,
            )

            result = await self.julia_env.execute.call_one(action)

            # The environment may report no reward; treat that as 0.0.
            reward = result.reward if result.reward is not None else 0.0
            obs = result.observation

            passed = obs.tests_passed
            failed = obs.tests_failed
            total = passed + failed

            # Log execution details
            print("JuliaEnv Execution Result:")
            print(f" Reward: {reward:.3f}")
            print(f" Tests Passed: {passed}")
            print(f" Tests Failed: {failed}")
            print(f" Total Tests: {total}")

            if obs.stderr:
                # Truncated to keep the log readable.
                print(f" Stderr: {obs.stderr[:200]}")
                record_metric("reward/julia/has_errors", 1, Reduce.SUM)

            if obs.error_message:
                print(f" Error Message: {obs.error_message[:200]}")

            # Log metrics
            record_metric("reward/julia/tests_passed", passed, Reduce.SUM)
            record_metric("reward/julia/tests_failed", failed, Reduce.SUM)
            record_metric("reward/julia/tests_total", total, Reduce.SUM)
            record_metric("reward/julia/pass_rate", reward, Reduce.MEAN)

            print(f"Final Reward: {reward:.3f}")
            print("=" * 80)

            return reward

        except asyncio.TimeoutError:
            print("✗ JuliaEnv request timeout - Reward: 0.0")
            print("=" * 80)
            record_metric("reward/julia/timeout_errors", 1, Reduce.SUM)
            return 0.0
        except Exception as e:
            # Deliberate best-effort boundary: a failed evaluation must not
            # crash the caller, so it is counted and scored as 0.0.
            print(f"✗ Unexpected error: {e} - Reward: 0.0")
            print("=" * 80)
            record_metric("reward/julia/evaluation_errors", 1, Reduce.SUM)
            return 0.0

    def _extract_code(self, response: str) -> str:
        """Extract Julia code from a model response.

        Fenced markdown blocks are searched anywhere in the text, so prose
        before or after the block is dropped. The first block tagged
        ``julia`` (case-insensitive) wins; otherwise the first fenced block
        of any kind is used. A fence that is never closed (e.g. truncated
        generation) extends to the end of the response. When no fenced
        block is found, a leading ```` ```julia ```` marker and a trailing
        ```` ``` ```` marker are stripped and the remaining text is returned.

        Args:
            response: Raw text produced by the model.

        Returns:
            The extracted code with surrounding whitespace removed; may be
            an empty string.
        """
        import re

        # (language tag, body) for every fence; the body stops at the
        # closing fence or, for an unclosed fence, at end of input.
        blocks = re.findall(
            r"```[ \t]*([\w+-]*)[^\n]*\n(.*?)(?:```|\Z)",
            response,
            flags=re.DOTALL,
        )
        for lang, body in blocks:
            if lang.lower() == "julia":
                return body.strip()
        if blocks:
            return blocks[0][1].strip()

        # No multi-line fence found: strip single-line fence markers.
        text = re.sub(r"^```julia\s*\n?", "", response, flags=re.IGNORECASE)
        text = re.sub(r"\n?```\s*$", "", text)
        return text.strip()
221335
222336@dataclass
@@ -349,7 +463,12 @@ async def sample(self) -> dict[str, str] | None:
349463
@endpoint
async def pad_token(self):
    """Return the token id to use for padding.

    The tokenizer's ``pad_token_id`` is returned when it is set; otherwise
    the tokenizer's ``eos_token_id`` is returned instead.
    """
    # Llama models don't have a pad token by default, hence the EOS fallback.
    pad_id = self._tokenizer.pad_token_id
    return self._tokenizer.eos_token_id if pad_id is None else pad_id
353472
354473
355474async def drop_weights (version : int ):
0 commit comments