Skip to content

Commit c72756a

Browse files
authored
Deprecating loss type mapping (#120)
1 parent ae4d594 commit c72756a

File tree

5 files changed

+6
-176
lines changed

5 files changed

+6
-176
lines changed

src/opentau/configs/default.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from opentau import (
3232
policies, # noqa: F401
3333
)
34-
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
3534
from opentau.datasets.transforms import ImageTransformsConfig
3635
from opentau.datasets.video_utils import get_safe_default_codec
3736

@@ -70,14 +69,11 @@ class DatasetConfig:
7069
stats: Dictionary of statistics for normalization, keyed by feature name.
7170
Each value is a dictionary with 'mean' and 'std' arrays. Defaults to None.
7271
data_features_name_mapping: Optional mapping from dataset feature names to
73-
standard feature names. Must be provided together with `loss_type_mapping`.
74-
Defaults to None.
75-
loss_type_mapping: Optional loss type mapping for the dataset. Must be
76-
provided together with `data_features_name_mapping`. Defaults to None.
72+
standard feature names. Defaults to None.
7773
7874
Raises:
7975
ValueError: If both or neither of `repo_id` and `grounding` are set, or
80-
if only one of `data_features_name_mapping` and `loss_type_mapping`
76+
if `data_features_name_mapping`
8177
is provided.
8278
"""
8379

@@ -94,7 +90,6 @@ class DatasetConfig:
9490

9591
# optional standard data format mapping for the dataset if mapping is not already in standard_data_format_mapping.py
9692
data_features_name_mapping: dict[str, str] | None = None
97-
loss_type_mapping: str | None = None
9893

9994
# Ratio of the dataset to be used for validation. Please specify a value.
10095
# If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
@@ -106,16 +101,7 @@ def __post_init__(self):
106101
if (self.repo_id is None) == (self.grounding is None):
107102
raise ValueError("Exactly one of `repo_id` or `grounding` for Dataset config should be set.")
108103

109-
# data_features_name_mapping and loss_type_mapping have to be provided together
110-
if (self.data_features_name_mapping is None) != (self.loss_type_mapping is None):
111-
raise ValueError(
112-
"`data_features_name_mapping` and `loss_type_mapping` have to be provided together."
113-
)
114-
115-
# add data_features_name_mapping and loss_type_mapping to standard_data_format_mapping.py if they are provided
116-
if self.data_features_name_mapping is not None and self.loss_type_mapping is not None:
117-
DATA_FEATURES_NAME_MAPPING[self.repo_id] = self.data_features_name_mapping
118-
LOSS_TYPE_MAPPING[self.repo_id] = self.loss_type_mapping
104+
# data_features_name_mapping has to be provided if it is not already in standard_data_format_mapping.py
119105

120106

121107
@dataclass

src/opentau/datasets/lerobot_dataset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
from opentau.constants import HF_OPENTAU_HOME
109109
from opentau.datasets.compute_stats import aggregate_stats, compute_episode_stats
110110
from opentau.datasets.image_writer import AsyncImageWriter, write_image
111-
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
111+
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING
112112
from opentau.datasets.utils import (
113113
DEFAULT_FEATURES,
114114
DEFAULT_IMAGE_PATH,
@@ -727,9 +727,6 @@ def _to_standard_data_format(self, item: dict) -> dict:
727727
standard_item["img_is_pad"] = torch.tensor(img_is_pad, dtype=torch.bool)
728728
standard_item["action_is_pad"] = item[name_map["actions"] + "_is_pad"]
729729

730-
# add loss type
731-
standard_item["loss_type"] = LOSS_TYPE_MAPPING[self._get_feature_mapping_key()]
732-
733730
# cast all tensors in standard_item to bfloat16
734731
for key, value in standard_item.items():
735732
if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:

src/opentau/datasets/standard_data_format_mapping.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,12 @@
4747
- "prompt": Task descriptions or prompts
4848
- "response": Expected responses or labels
4949
50-
LOSS_TYPE_MAPPING
51-
Dictionary mapping dataset repository IDs to loss type strings. Valid
52-
values are:
53-
54-
- "MSE": Mean Squared Error (typically for continuous robotic actions)
55-
- "CE": Cross Entropy (typically for discrete classification tasks
56-
like VQA)
57-
5850
Example:
5951
Access feature name mapping for a dataset:
6052
>>> mapping = DATA_FEATURES_NAME_MAPPING["lerobot/aloha_mobile_cabinet"]
6153
>>> mapping["camera0"] # Returns "observation.images.cam_right_wrist"
6254
>>> mapping["actions"] # Returns "action"
6355
64-
Access loss type for a dataset:
65-
>>> loss_type = LOSS_TYPE_MAPPING["lerobot/aloha_mobile_cabinet"]
66-
>>> loss_type # Returns "MSE"
6756
"""
6857

6958
DATA_FEATURES_NAME_MAPPING = {
@@ -247,32 +236,3 @@
247236
"response": "response",
248237
},
249238
}
250-
251-
"""
252-
Use "MSE" for mean squared error and "CE" for cross entropy.
253-
Usually robotic data with actions will have an MSE loss while
254-
VQA tasks will have a CE loss.
255-
"""
256-
LOSS_TYPE_MAPPING = {
257-
"ML-GOD/mt-button-press": "MSE",
258-
"ML-GOD/libero_spatial_no_noops_1.0.0_lerobot": "MSE",
259-
"ML-GOD/libero": "MSE",
260-
"physical-intelligence/libero": "MSE",
261-
"danaaubakirova/koch_test": "MSE",
262-
"lerobot/droid_100": "MSE",
263-
"lerobot/aloha_mobile_cabinet": "MSE",
264-
"autox/agibot-sample": "MSE",
265-
"bi-so100-block-manipulation": "MSE",
266-
"cube-on-cylinder": "MSE",
267-
"cylinder-on-cube": "MSE",
268-
"l-shape-on-cross-shape": "MSE",
269-
"lerobot/svla_so101_pickplace": "MSE",
270-
"lerobot/svla_so100_pickplace": "MSE",
271-
"lerobot/svla_so100_stacking": "MSE",
272-
"pixmo": "CE",
273-
"dummy": "CE",
274-
"vsr": "CE",
275-
"clevr": "CE",
276-
"cocoqa": "CE",
277-
"lerobot_dummy": "MSE",
278-
}

tests/configs/test_default.py

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from opentau.configs.default import DatasetConfig, DatasetMixtureConfig
18-
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
18+
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING
1919

2020

2121
@pytest.mark.parametrize(
@@ -94,121 +94,9 @@ def setup_method(self):
9494
"""Set up test fixtures before each test method."""
9595
# Store original state of global mappings
9696
self.original_data_mapping = DATA_FEATURES_NAME_MAPPING.copy()
97-
self.original_loss_mapping = LOSS_TYPE_MAPPING.copy()
9897

9998
def teardown_method(self):
10099
"""Clean up after each test method."""
101100
# Restore original state of global mappings
102101
DATA_FEATURES_NAME_MAPPING.clear()
103102
DATA_FEATURES_NAME_MAPPING.update(self.original_data_mapping)
104-
LOSS_TYPE_MAPPING.clear()
105-
LOSS_TYPE_MAPPING.update(self.original_loss_mapping)
106-
107-
@pytest.mark.parametrize(
108-
"data_mapping, loss_mapping, should_raise",
109-
[
110-
(None, None, False), # Both None - valid
111-
({"camera0": "image"}, "MSE", False), # Both provided - valid
112-
(None, "MSE", True), # Only loss_mapping provided - invalid
113-
({"camera0": "image"}, None, True), # Only data_mapping provided - invalid
114-
],
115-
)
116-
def test_data_mapping_validation(self, data_mapping, loss_mapping, should_raise):
117-
"""Test that data_features_name_mapping and loss_type_mapping must be provided together."""
118-
if should_raise:
119-
with pytest.raises(
120-
ValueError,
121-
match="`data_features_name_mapping` and `loss_type_mapping` have to be provided together.",
122-
):
123-
DatasetConfig(
124-
repo_id="test_repo",
125-
data_features_name_mapping=data_mapping,
126-
loss_type_mapping=loss_mapping,
127-
)
128-
else:
129-
# Should not raise an error
130-
DatasetConfig(
131-
repo_id="test_repo", data_features_name_mapping=data_mapping, loss_type_mapping=loss_mapping
132-
)
133-
134-
def test_mapping_addition_to_global_dicts(self):
135-
"""Test that mappings are added to global dictionaries when both are provided."""
136-
test_repo_id = "test_custom_repo"
137-
test_data_mapping = {"camera0": "observation.image", "state": "observation.state"}
138-
test_loss_mapping = "MSE"
139-
140-
# Ensure the repo_id is not already in the mappings
141-
assert test_repo_id not in DATA_FEATURES_NAME_MAPPING
142-
assert test_repo_id not in LOSS_TYPE_MAPPING
143-
144-
# Create DatasetConfig with both mappings
145-
config = DatasetConfig( # noqa: F841
146-
repo_id=test_repo_id,
147-
data_features_name_mapping=test_data_mapping,
148-
loss_type_mapping=test_loss_mapping,
149-
)
150-
151-
# Check that mappings were added to global dictionaries
152-
assert test_repo_id in DATA_FEATURES_NAME_MAPPING
153-
assert test_repo_id in LOSS_TYPE_MAPPING
154-
assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == test_data_mapping
155-
assert LOSS_TYPE_MAPPING[test_repo_id] == test_loss_mapping
156-
157-
def test_mapping_not_added_when_both_none(self):
158-
"""Test that mappings are not added to global dictionaries when both are None."""
159-
test_repo_id = "test_none_repo"
160-
161-
# Ensure the repo_id is not already in the mappings
162-
assert test_repo_id not in DATA_FEATURES_NAME_MAPPING
163-
assert test_repo_id not in LOSS_TYPE_MAPPING
164-
165-
# Create DatasetConfig with both mappings as None
166-
config = DatasetConfig(repo_id=test_repo_id, data_features_name_mapping=None, loss_type_mapping=None) # noqa: F841
167-
168-
# Check that mappings were not added to global dictionaries
169-
assert test_repo_id not in DATA_FEATURES_NAME_MAPPING
170-
assert test_repo_id not in LOSS_TYPE_MAPPING
171-
172-
def test_mapping_overwrites_existing(self):
173-
"""Test that providing mappings overwrites existing entries for the same repo_id."""
174-
test_repo_id = "test_overwrite_repo"
175-
original_data_mapping = {"old": "mapping"}
176-
original_loss_mapping = "CE"
177-
new_data_mapping = {"camera0": "observation.image", "state": "observation.state"}
178-
new_loss_mapping = "MSE"
179-
180-
# Add original mappings
181-
DATA_FEATURES_NAME_MAPPING[test_repo_id] = original_data_mapping
182-
LOSS_TYPE_MAPPING[test_repo_id] = original_loss_mapping
183-
184-
# Create DatasetConfig with new mappings
185-
config = DatasetConfig( # noqa: F841
186-
repo_id=test_repo_id,
187-
data_features_name_mapping=new_data_mapping,
188-
loss_type_mapping=new_loss_mapping,
189-
)
190-
191-
# Check that mappings were overwritten
192-
assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == new_data_mapping
193-
assert LOSS_TYPE_MAPPING[test_repo_id] == new_loss_mapping
194-
assert DATA_FEATURES_NAME_MAPPING[test_repo_id] != original_data_mapping
195-
assert LOSS_TYPE_MAPPING[test_repo_id] != original_loss_mapping
196-
197-
def test_empty_mappings(self):
198-
"""Test behavior with empty mappings."""
199-
test_repo_id = "test_empty_repo"
200-
empty_data_mapping = {}
201-
test_loss_mapping = "MSE"
202-
203-
# Create DatasetConfig with empty data mapping
204-
config = DatasetConfig( # noqa: F841
205-
repo_id=test_repo_id,
206-
data_features_name_mapping=empty_data_mapping,
207-
loss_type_mapping=test_loss_mapping,
208-
)
209-
210-
# Check that empty mapping was added
211-
assert test_repo_id in DATA_FEATURES_NAME_MAPPING
212-
assert test_repo_id in LOSS_TYPE_MAPPING
213-
assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == empty_data_mapping
214-
assert LOSS_TYPE_MAPPING[test_repo_id] == test_loss_mapping

tests/datasets/test_datasets.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,6 @@ def check_standard_data_format(item, delta_timestamps_params, dataset, train_pip
310310
("actions", (train_pipeline_config.action_chunk, train_pipeline_config.max_action_dim)),
311311
("prompt", None),
312312
("response", None),
313-
("loss_type", None),
314313
("img_is_pad", (train_pipeline_config.num_cams,)),
315314
("action_is_pad", (train_pipeline_config.action_chunk,)),
316315
]
@@ -329,7 +328,7 @@ def check_standard_data_format(item, delta_timestamps_params, dataset, train_pip
329328
assert item[key].shape == shape, f"{key}"
330329
elif key == "state" or key == "actions":
331330
assert item[key].shape == shape, f"{key}"
332-
elif key == "prompt" or key == "response" or key == "loss_type":
331+
elif key == "prompt" or key == "response":
333332
assert type(item[key]) is str, f"{key}"
334333
elif key == "img_is_pad" or key == "action_is_pad":
335334
assert item[key].shape == shape, f"{key}"

0 commit comments

Comments
 (0)