Skip to content

Commit f9f4b05

Browse files
authored
Feat/value response (#109)
1 parent c72756a commit f9f4b05

File tree

6 files changed

+296
-71
lines changed

6 files changed

+296
-71
lines changed

configs/examples/value_config.json

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
"dataset_mixture": {
33
"datasets": [
44
{
5-
"repo_id": "physical-intelligence/libero",
6-
"episodes": [0,1,2,3,4,5,6,7,8,9]
5+
"repo_id": "physical-intelligence/libero"
6+
},
7+
{
8+
"grounding": "clevr"
79
}
810
],
911
"weights": [
12+
1.0,
1013
1.0
1114
],
1215
"action_freq": 30.0,
@@ -22,7 +25,8 @@
2225
"VALUE": "MEAN_STD"
2326
},
2427
"max_state_dim": 32,
25-
"tokenizer_max_length": 52,
28+
"prompt_max_length": 256,
29+
"response_max_length": 52,
2630
"reward_config": {
2731
"number_of_bins": 201,
2832
"C_neg": -1000.0,
@@ -41,10 +45,11 @@
4145
"gradient_accumulation_steps": 1,
4246
"dataloader_batch_size": 2,
4347
"prefetch_factor": 2,
44-
"steps": 100,
45-
"log_freq": 1,
48+
"steps": 10000,
49+
"log_freq": 100,
50+
"val_freq": 500,
4651
"save_checkpoint": true,
47-
"save_freq": 100,
52+
"save_freq": 1000,
4853
"use_policy_training_preset": false,
4954
"trace_nans": true,
5055
"optimizer": {

src/opentau/policies/value/configuration_value.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ class ValueConfig(PreTrainedConfig):
7777
empty_cameras: int = 0
7878

7979
# Tokenizer
80-
tokenizer_max_length: int = 48
80+
prompt_max_length: int = 48
81+
response_max_length: int = 52
8182

8283
# Reward config
8384
reward_config: RewardConfig = field(default_factory=RewardConfig)

src/opentau/policies/value/modeling_value.py

Lines changed: 171 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
Uses SIGLIP for vision encoding and Gemma 3 270M for language processing.
2222
"""
2323

24+
import logging
25+
import math
26+
27+
import numpy as np
2428
import torch
2529
import torch.nn.functional as F # noqa: N812
2630
from einops import rearrange
@@ -228,9 +232,8 @@ def predict_value(self, batch: dict[str, Tensor]) -> Tensor:
228232

229233
images, img_masks = self.prepare_images(batch)
230234
lang_tokens, lang_masks = self.prepare_language(batch)
231-
state = batch.get("state")
232235

233-
logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
236+
logits = self.model.get_value(images, img_masks, lang_tokens, lang_masks)
234237
return self.calculate_value(logits)
235238

236239
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | None]:
@@ -246,22 +249,56 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] |
246249

247250
images, img_masks = self.prepare_images(batch)
248251
lang_tokens, lang_masks = self.prepare_language(batch)
249-
state = batch.get("state")
252+
response_tokens, response_masks = self.prepare_response(batch)
250253

251-
logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
252-
values = self.calculate_value(logits)
254+
value_logits, response_logits = self.model.forward(
255+
images, img_masks, lang_tokens, lang_masks, response_tokens, response_masks
256+
)
257+
values = self.calculate_value(value_logits)
253258
# Compute Cross-Entropy loss
254-
logits = logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
259+
value_logits = value_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
255260
batch["return_bin_idx"] = batch["return_bin_idx"].to(dtype=torch.long)
256-
loss = F.cross_entropy(logits, batch["return_bin_idx"])
261+
value_ce_loss = F.cross_entropy(value_logits, batch["return_bin_idx"], reduction="none")
262+
263+
action_is_pad = batch.get("action_is_pad")
264+
265+
# mask for differentiating between robotic and VQA datasets
266+
diff_mask = action_is_pad.all(dim=1)
267+
# Mask CE loss if all action_is_pad are true. This is used for VQA dataset where we don't have actions tokens.
268+
value_ce_loss = value_ce_loss * (~diff_mask).float()
269+
270+
value_ce_loss = value_ce_loss.mean()
257271

258272
l1_loss = F.l1_loss(values, batch["return_continuous"])
259273

260-
accuracy = (logits.argmax(dim=-1) == batch["return_bin_idx"]).float().mean()
274+
# Accuracy only over robotic samples (exclude VQA where diff_mask is True)
275+
robotic_mask = ~diff_mask
276+
correct = (value_logits.argmax(dim=-1) == batch["return_bin_idx"]).float() * robotic_mask.float()
277+
num_robotic = robotic_mask.float().sum()
278+
accuracy = correct.sum() / num_robotic.clamp(min=1)
279+
280+
batch_size, seq_len = response_logits.shape[0], response_logits.shape[1]
281+
response_slice = slice(1, None)
282+
response_logits = response_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
283+
response_logits = rearrange(response_logits, "b s d -> (b s) d")
284+
response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
285+
response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")
286+
287+
response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)
288+
289+
# remove pad tokens
290+
response_is_pad = ~response_masks # convert into format where value for pad is True
291+
# Mask response loss if response is padded
292+
response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]
293+
# Mask response loss if all action_is_pad are true. This is used for Robotic dataset where we have at least one actions tokens.
294+
response_ce_loss = response_ce_loss * rearrange(diff_mask.float(), "b -> b 1")
295+
296+
# compute mean
297+
response_ce_loss = response_ce_loss.mean()
261298

262299
return {
263-
"MSE": torch.zeros_like(loss, requires_grad=False),
264-
"CE": loss,
300+
"MSE": torch.zeros_like(value_ce_loss, requires_grad=False),
301+
"CE": value_ce_loss + response_ce_loss,
265302
"L1": l1_loss,
266303
"Accuracy": accuracy,
267304
}
@@ -321,6 +358,35 @@ def prepare_images(self, batch):
321358

322359
return images, img_masks
323360

361+
def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
362+
"""Discretizes the state into bins and converts it to a string representation.
363+
364+
Each dimension of the state vector is discretized into 256 bins.
365+
The values of each dimension of the state are expected to be in the range [-1, 1].
366+
The discretization bins are linearly spaced between -1 and 1.
367+
The index of the bin for each dimension is then concatenated into a space-separated string.
368+
369+
Args:
370+
batch: Batch of data containing the "state" tensor.
371+
372+
Returns:
373+
A list of strings, where each string is a space-separated list of discretized state values.
374+
375+
Raises:
376+
ValueError: If the state values are not normalized between -1 and 1.
377+
"""
378+
state = batch["state"]
379+
state_np = state.to(device="cpu", dtype=torch.float32).numpy()
380+
if np.any(state_np < -1.0) or np.any(state_np > 1.0):
381+
logging.warning(
382+
f"State values are not normalized between -1 and 1. Min: {state_np.min()}, Max: {state_np.max()}"
383+
)
384+
state_np = np.clip(state_np, -1.0, 1.0)
385+
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
386+
return [
387+
" ".join(map(str, row)) for row in discretized_states
388+
] # TODO: return a tensor instead of a list of strings?
389+
324390
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
325391
"""Tokenizes the text input for the model.
326392
@@ -333,21 +399,54 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
333399
device = batch.get("state", list(batch.values())[0]).device
334400
tasks = batch["prompt"]
335401

336-
# PaliGemma prompt has to end with a new line
337-
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
402+
state = self.prepare_discrete_state(batch)
403+
# using <eos> to separate each modality
404+
prompt = [f"Task: {task}<eos>State: {state}<eos>" for task, state in zip(tasks, state, strict=False)]
338405

339406
tokenized_prompt = self.language_tokenizer.__call__(
340-
tasks,
407+
prompt,
341408
padding="max_length",
342409
padding_side="right",
343-
max_length=self.config.tokenizer_max_length,
410+
max_length=self.config.prompt_max_length,
344411
return_tensors="pt",
412+
truncation=True,
345413
)
346414
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
347415
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
348416

349417
return lang_tokens, lang_masks
350418

419+
def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
420+
"""Tokenize the response input.
421+
422+
Args:
423+
batch: Batch of data containing the key "response".
424+
425+
Returns:
426+
A tuple containing:
427+
- response_tokens: Tensor of response language tokens.
428+
- response_masks: Tensor of response language attention masks.
429+
"""
430+
431+
device = batch["state"].device
432+
responses = batch["response"]
433+
434+
# if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response.
435+
response_prompt = [f"{response}" for response in responses]
436+
437+
tokenized_response = self.language_tokenizer.__call__(
438+
response_prompt,
439+
padding="max_length",
440+
padding_side="right",
441+
max_length=self.config.response_max_length,
442+
return_tensors="pt",
443+
truncation=True,
444+
)
445+
response_tokens = tokenized_response["input_ids"].to(device=device)
446+
response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool)
447+
448+
return response_tokens, response_masks
449+
351450

352451
class ValueModel(nn.Module):
353452
"""
@@ -376,8 +475,6 @@ class ValueModel(nn.Module):
376475
└──────────────────────────────┘
377476
"""
378477

379-
CLASSIFICATION_TOKEN_ID = 6 # unused token id in Gemma 3 270M that we repurpose for classification
380-
381478
def __init__(self, config):
382479
"""Initializes the ValueModel.
383480
@@ -388,7 +485,8 @@ def __init__(self, config):
388485
self.config = config
389486

390487
siglip_gemma_value_config = SiglipGemmaValueConfig(
391-
num_value_bins=self.config.reward_config.number_of_bins
488+
num_value_bins=self.config.reward_config.number_of_bins,
489+
response_max_length=self.config.response_max_length,
392490
)
393491
self.siglip_gemma_value = SiglipGemmaValueModel(siglip_gemma_value_config)
394492

@@ -399,7 +497,13 @@ def __init__(self, config):
399497
self.c_neg = config.reward_config.C_neg
400498

401499
def embed_sequence(
402-
self, images, img_masks, lang_tokens, lang_masks, state
500+
self,
501+
images,
502+
img_masks,
503+
lang_tokens,
504+
lang_masks,
505+
response_tokens: torch.Tensor | None = None,
506+
response_masks: torch.Tensor | None = None,
403507
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
404508
"""Embeds sequence of images and language tokens.
405509
@@ -451,25 +555,19 @@ def embed_sequence(
451555
num_lang_embs = lang_emb.shape[1]
452556
att_masks += [0] * num_lang_embs
453557

454-
# embed state
455-
state_emb = self.state_proj(state)
456-
state_emb = state_emb.to(dtype=torch.bfloat16)
457-
embs.append(state_emb[:, None, :])
558+
if response_tokens is not None:
559+
response_emb = self.siglip_gemma_value.embed_language_tokens(response_tokens)
458560

459-
state_mask = torch.ones(state_emb.shape[0], 1, dtype=torch.bool, device=state_emb.device)
460-
pad_masks.append(state_mask)
561+
# Normalize response language embeddings
562+
response_emb_dim = response_emb.shape[-1]
563+
response_emb = response_emb * math.sqrt(response_emb_dim)
461564

462-
# full attention between state and image and language inputs
463-
att_masks += [0]
565+
embs.append(response_emb)
566+
pad_masks.append(response_masks)
464567

465-
# add classification token
466-
cls_token = torch.full(
467-
(bsize, 1), self.CLASSIFICATION_TOKEN_ID, device=state_emb.device, dtype=torch.long
468-
)
469-
cls_token_emb = self.siglip_gemma_value.gemma.embed_tokens(cls_token)
470-
embs.append(cls_token_emb)
471-
pad_masks.append(torch.ones(bsize, 1, dtype=torch.bool, device=state_emb.device))
472-
att_masks += [0]
568+
# full attention between image, language and response inputs
569+
num_response_embs = response_emb.shape[1]
570+
att_masks += [1] * num_response_embs
473571

474572
embs = torch.cat(embs, dim=1)
475573
pad_masks = torch.cat(pad_masks, dim=1)
@@ -484,7 +582,42 @@ def forward(
484582
img_masks: list[torch.Tensor],
485583
lang_tokens: torch.Tensor,
486584
lang_masks: torch.Tensor,
487-
state: torch.Tensor | None = None,
585+
response_tokens: torch.Tensor | None = None,
586+
response_masks: torch.Tensor | None = None,
587+
) -> torch.Tensor:
588+
"""Predict value estimates given observations.
589+
590+
Args:
591+
images: List of image tensors
592+
img_masks: List of image masks
593+
lang_tokens: Language token IDs
594+
lang_masks: Language attention masks
595+
state: Optional state tensor
596+
597+
Returns:
598+
Tensor of shape [batch_size, 1] containing value estimates
599+
"""
600+
embs, pad_masks, att_masks = self.embed_sequence(
601+
images, img_masks, lang_tokens, lang_masks, response_tokens, response_masks
602+
)
603+
604+
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
605+
position_ids = torch.cumsum(pad_masks, dim=1) - 1
606+
607+
value_logits, response_logits = self.siglip_gemma_value.forward(
608+
inputs_embeds=embs,
609+
attention_mask=att_2d_masks,
610+
position_ids=position_ids,
611+
)
612+
613+
return value_logits, response_logits
614+
615+
def get_value(
616+
self,
617+
images: list[torch.Tensor],
618+
img_masks: list[torch.Tensor],
619+
lang_tokens: torch.Tensor,
620+
lang_masks: torch.Tensor,
488621
) -> torch.Tensor:
489622
"""Predict value estimates given observations.
490623
@@ -498,15 +631,15 @@ def forward(
498631
Returns:
499632
Tensor of shape [batch_size, 1] containing value estimates
500633
"""
501-
embs, pad_masks, att_masks = self.embed_sequence(images, img_masks, lang_tokens, lang_masks, state)
634+
embs, pad_masks, att_masks = self.embed_sequence(images, img_masks, lang_tokens, lang_masks)
502635

503636
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
504637
position_ids = torch.cumsum(pad_masks, dim=1) - 1
505638

506-
logits = self.siglip_gemma_value.forward(
639+
value_logits = self.siglip_gemma_value.get_value(
507640
inputs_embeds=embs,
508641
attention_mask=att_2d_masks,
509642
position_ids=position_ids,
510643
)
511644

512-
return logits
645+
return value_logits

0 commit comments

Comments
 (0)