Uses SIGLIP for vision encoding and Gemma 3 270M for language processing.
"""
2323
24+ import logging
25+ import math
26+
27+ import numpy as np
2428import torch
2529import torch .nn .functional as F # noqa: N812
2630from einops import rearrange
@@ -228,9 +232,8 @@ def predict_value(self, batch: dict[str, Tensor]) -> Tensor:
228232
229233 images , img_masks = self .prepare_images (batch )
230234 lang_tokens , lang_masks = self .prepare_language (batch )
231- state = batch .get ("state" )
232235
233- logits = self .model .forward (images , img_masks , lang_tokens , lang_masks , state )
236+ logits = self .model .get_value (images , img_masks , lang_tokens , lang_masks )
234237 return self .calculate_value (logits )
235238
236239 def forward (self , batch : dict [str , Tensor ]) -> tuple [Tensor , dict [str , Tensor ] | None ]:
@@ -246,22 +249,56 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] |
246249
247250 images , img_masks = self .prepare_images (batch )
248251 lang_tokens , lang_masks = self .prepare_language (batch )
249- state = batch . get ( "state" )
252+ response_tokens , response_masks = self . prepare_response ( batch )
250253
251- logits = self .model .forward (images , img_masks , lang_tokens , lang_masks , state )
252- values = self .calculate_value (logits )
254+ value_logits , response_logits = self .model .forward (
255+ images , img_masks , lang_tokens , lang_masks , response_tokens , response_masks
256+ )
257+ values = self .calculate_value (value_logits )
253258 # Compute Cross-Entropy loss
254- logits = logits .to (dtype = torch .float32 ) # upcast to float32 for loss calculation
259+ value_logits = value_logits .to (dtype = torch .float32 ) # upcast to float32 for loss calculation
255260 batch ["return_bin_idx" ] = batch ["return_bin_idx" ].to (dtype = torch .long )
256- loss = F .cross_entropy (logits , batch ["return_bin_idx" ])
261+ value_ce_loss = F .cross_entropy (value_logits , batch ["return_bin_idx" ], reduction = "none" )
262+
263+ action_is_pad = batch .get ("action_is_pad" )
264+
265+ # mask for differentiating between robotic and VQA datasets
266+ diff_mask = action_is_pad .all (dim = 1 )
267+ # Mask CE loss if all action_is_pad are true. This is used for VQA dataset where we don't have actions tokens.
268+ value_ce_loss = value_ce_loss * (~ diff_mask ).float ()
269+
270+ value_ce_loss = value_ce_loss .mean ()
257271
258272 l1_loss = F .l1_loss (values , batch ["return_continuous" ])
259273
260- accuracy = (logits .argmax (dim = - 1 ) == batch ["return_bin_idx" ]).float ().mean ()
274+ # Accuracy only over robotic samples (exclude VQA where diff_mask is True)
275+ robotic_mask = ~ diff_mask
276+ correct = (value_logits .argmax (dim = - 1 ) == batch ["return_bin_idx" ]).float () * robotic_mask .float ()
277+ num_robotic = robotic_mask .float ().sum ()
278+ accuracy = correct .sum () / num_robotic .clamp (min = 1 )
279+
280+ batch_size , seq_len = response_logits .shape [0 ], response_logits .shape [1 ]
281+ response_slice = slice (1 , None )
282+ response_logits = response_logits .to (dtype = torch .float32 ) # upcast to float32 for loss calculation
283+ response_logits = rearrange (response_logits , "b s d -> (b s) d" )
284+ response_labels = rearrange (response_tokens [:, response_slice ], "b s -> (b s)" )
285+ response_ce_loss = F .cross_entropy (response_logits , response_labels , reduction = "none" )
286+
287+ response_ce_loss = rearrange (response_ce_loss , "(b s) -> b s" , b = batch_size , s = seq_len )
288+
289+ # remove pad tokens
290+ response_is_pad = ~ response_masks # convert into format where value for pad is True
291+ # Mask response loss if response is padded
292+ response_ce_loss = response_ce_loss * ~ response_is_pad [:, response_slice ]
293+ # Mask response loss if all action_is_pad are true. This is used for Robotic dataset where we have at least one actions tokens.
294+ response_ce_loss = response_ce_loss * rearrange (diff_mask .float (), "b -> b 1" )
295+
296+ # compute mean
297+ response_ce_loss = response_ce_loss .mean ()
261298
262299 return {
263- "MSE" : torch .zeros_like (loss , requires_grad = False ),
264- "CE" : loss ,
300+ "MSE" : torch .zeros_like (value_ce_loss , requires_grad = False ),
301+ "CE" : value_ce_loss + response_ce_loss ,
265302 "L1" : l1_loss ,
266303 "Accuracy" : accuracy ,
267304 }
@@ -321,6 +358,35 @@ def prepare_images(self, batch):
321358
322359 return images , img_masks
323360
361+ def prepare_discrete_state (self , batch : dict [str , Tensor ]) -> list [str ]:
362+ """Discretizes the state into bins and converts it to a string representation.
363+
364+ Each dimension of the state vector is discretized into 256 bins.
365+ The values of each dimension of the state are expected to be in the range [-1, 1].
366+ The discretization bins are linearly spaced between -1 and 1.
367+ The index of the bin for each dimension is then concatenated into a space-separated string.
368+
369+ Args:
370+ batch: Batch of data containing the "state" tensor.
371+
372+ Returns:
373+ A list of strings, where each string is a space-separated list of discretized state values.
374+
375+ Raises:
376+ ValueError: If the state values are not normalized between -1 and 1.
377+ """
378+ state = batch ["state" ]
379+ state_np = state .to (device = "cpu" , dtype = torch .float32 ).numpy ()
380+ if np .any (state_np < - 1.0 ) or np .any (state_np > 1.0 ):
381+ logging .warning (
382+ f"State values are not normalized between -1 and 1. Min: { state_np .min ()} , Max: { state_np .max ()} "
383+ )
384+ state_np = np .clip (state_np , - 1.0 , 1.0 )
385+ discretized_states = np .digitize (state_np , bins = np .linspace (- 1 , 1 , 256 + 1 )[:- 1 ]) - 1
386+ return [
387+ " " .join (map (str , row )) for row in discretized_states
388+ ] # TODO: return a tensor instead of a list of strings?
389+
324390 def prepare_language (self , batch ) -> tuple [Tensor , Tensor ]:
325391 """Tokenizes the text input for the model.
326392
@@ -333,21 +399,54 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
333399 device = batch .get ("state" , list (batch .values ())[0 ]).device
334400 tasks = batch ["prompt" ]
335401
336- # PaliGemma prompt has to end with a new line
337- tasks = [task if task .endswith ("\n " ) else f"{ task } \n " for task in tasks ]
402+ state = self .prepare_discrete_state (batch )
403+ # using <eos> to separate each modality
404+ prompt = [f"Task: { task } <eos>State: { state } <eos>" for task , state in zip (tasks , state , strict = False )]
338405
339406 tokenized_prompt = self .language_tokenizer .__call__ (
340- tasks ,
407+ prompt ,
341408 padding = "max_length" ,
342409 padding_side = "right" ,
343- max_length = self .config .tokenizer_max_length ,
410+ max_length = self .config .prompt_max_length ,
344411 return_tensors = "pt" ,
412+ truncation = True ,
345413 )
346414 lang_tokens = tokenized_prompt ["input_ids" ].to (device = device )
347415 lang_masks = tokenized_prompt ["attention_mask" ].to (device = device , dtype = torch .bool )
348416
349417 return lang_tokens , lang_masks
350418
419+ def prepare_response (self , batch : dict [str , Tensor ]) -> tuple [Tensor , Tensor ]:
420+ """Tokenize the response input.
421+
422+ Args:
423+ batch: Batch of data containing the key "response".
424+
425+ Returns:
426+ A tuple containing:
427+ - response_tokens: Tensor of response language tokens.
428+ - response_masks: Tensor of response language attention masks.
429+ """
430+
431+ device = batch ["state" ].device
432+ responses = batch ["response" ]
433+
434+ # if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response.
435+ response_prompt = [f"{ response } " for response in responses ]
436+
437+ tokenized_response = self .language_tokenizer .__call__ (
438+ response_prompt ,
439+ padding = "max_length" ,
440+ padding_side = "right" ,
441+ max_length = self .config .response_max_length ,
442+ return_tensors = "pt" ,
443+ truncation = True ,
444+ )
445+ response_tokens = tokenized_response ["input_ids" ].to (device = device )
446+ response_masks = tokenized_response ["attention_mask" ].to (device = device , dtype = torch .bool )
447+
448+ return response_tokens , response_masks
449+
351450
352451class ValueModel (nn .Module ):
353452 """
@@ -376,8 +475,6 @@ class ValueModel(nn.Module):
376475 └──────────────────────────────┘
377476 """
378477
379- CLASSIFICATION_TOKEN_ID = 6 # unused token id in Gemma 3 270M that we repurpose for classification
380-
381478 def __init__ (self , config ):
382479 """Initializes the ValueModel.
383480
@@ -388,7 +485,8 @@ def __init__(self, config):
388485 self .config = config
389486
390487 siglip_gemma_value_config = SiglipGemmaValueConfig (
391- num_value_bins = self .config .reward_config .number_of_bins
488+ num_value_bins = self .config .reward_config .number_of_bins ,
489+ response_max_length = self .config .response_max_length ,
392490 )
393491 self .siglip_gemma_value = SiglipGemmaValueModel (siglip_gemma_value_config )
394492
@@ -399,7 +497,13 @@ def __init__(self, config):
399497 self .c_neg = config .reward_config .C_neg
400498
401499 def embed_sequence (
402- self , images , img_masks , lang_tokens , lang_masks , state
500+ self ,
501+ images ,
502+ img_masks ,
503+ lang_tokens ,
504+ lang_masks ,
505+ response_tokens : torch .Tensor | None = None ,
506+ response_masks : torch .Tensor | None = None ,
403507 ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
404508 """Embeds sequence of images and language tokens.
405509
@@ -451,25 +555,19 @@ def embed_sequence(
451555 num_lang_embs = lang_emb .shape [1 ]
452556 att_masks += [0 ] * num_lang_embs
453557
454- # embed state
455- state_emb = self .state_proj (state )
456- state_emb = state_emb .to (dtype = torch .bfloat16 )
457- embs .append (state_emb [:, None , :])
558+ if response_tokens is not None :
559+ response_emb = self .siglip_gemma_value .embed_language_tokens (response_tokens )
458560
459- state_mask = torch .ones (state_emb .shape [0 ], 1 , dtype = torch .bool , device = state_emb .device )
460- pad_masks .append (state_mask )
561+ # Normalize response language embeddings
562+ response_emb_dim = response_emb .shape [- 1 ]
563+ response_emb = response_emb * math .sqrt (response_emb_dim )
461564
462- # full attention between state and image and language inputs
463- att_masks += [ 0 ]
565+ embs . append ( response_emb )
566+ pad_masks . append ( response_masks )
464567
465- # add classification token
466- cls_token = torch .full (
467- (bsize , 1 ), self .CLASSIFICATION_TOKEN_ID , device = state_emb .device , dtype = torch .long
468- )
469- cls_token_emb = self .siglip_gemma_value .gemma .embed_tokens (cls_token )
470- embs .append (cls_token_emb )
471- pad_masks .append (torch .ones (bsize , 1 , dtype = torch .bool , device = state_emb .device ))
472- att_masks += [0 ]
568+ # full attention between image, language and response inputs
569+ num_response_embs = response_emb .shape [1 ]
570+ att_masks += [1 ] * num_response_embs
473571
474572 embs = torch .cat (embs , dim = 1 )
475573 pad_masks = torch .cat (pad_masks , dim = 1 )
@@ -484,7 +582,42 @@ def forward(
484582 img_masks : list [torch .Tensor ],
485583 lang_tokens : torch .Tensor ,
486584 lang_masks : torch .Tensor ,
487- state : torch .Tensor | None = None ,
585+ response_tokens : torch .Tensor | None = None ,
586+ response_masks : torch .Tensor | None = None ,
587+ ) -> torch .Tensor :
588+ """Predict value estimates given observations.
589+
590+ Args:
591+ images: List of image tensors
592+ img_masks: List of image masks
593+ lang_tokens: Language token IDs
594+ lang_masks: Language attention masks
595+ state: Optional state tensor
596+
597+ Returns:
598+ Tensor of shape [batch_size, 1] containing value estimates
599+ """
600+ embs , pad_masks , att_masks = self .embed_sequence (
601+ images , img_masks , lang_tokens , lang_masks , response_tokens , response_masks
602+ )
603+
604+ att_2d_masks = make_att_2d_masks (pad_masks , att_masks )
605+ position_ids = torch .cumsum (pad_masks , dim = 1 ) - 1
606+
607+ value_logits , response_logits = self .siglip_gemma_value .forward (
608+ inputs_embeds = embs ,
609+ attention_mask = att_2d_masks ,
610+ position_ids = position_ids ,
611+ )
612+
613+ return value_logits , response_logits
614+
615+ def get_value (
616+ self ,
617+ images : list [torch .Tensor ],
618+ img_masks : list [torch .Tensor ],
619+ lang_tokens : torch .Tensor ,
620+ lang_masks : torch .Tensor ,
488621 ) -> torch .Tensor :
489622 """Predict value estimates given observations.
490623
@@ -498,15 +631,15 @@ def forward(
498631 Returns:
499632 Tensor of shape [batch_size, 1] containing value estimates
500633 """
501- embs , pad_masks , att_masks = self .embed_sequence (images , img_masks , lang_tokens , lang_masks , state )
634+ embs , pad_masks , att_masks = self .embed_sequence (images , img_masks , lang_tokens , lang_masks )
502635
503636 att_2d_masks = make_att_2d_masks (pad_masks , att_masks )
504637 position_ids = torch .cumsum (pad_masks , dim = 1 ) - 1
505638
506- logits = self .siglip_gemma_value .forward (
639+ value_logits = self .siglip_gemma_value .get_value (
507640 inputs_embeds = embs ,
508641 attention_mask = att_2d_masks ,
509642 position_ids = position_ids ,
510643 )
511644
512- return logits
645+ return value_logits
0 commit comments