Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ether0/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def filter_problem_types(
columns = (
next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
).column_names
# molrqa uses 'problem_type'; t[-r1] use 'type'
# ether0-benchmark uses 'problem_type'; some variants may use 'type'
type_col = "problem_type" if "problem_type" in columns else "type"

if any(pt.startswith("re:") for pt in problem_types):
Expand Down
6 changes: 3 additions & 3 deletions src/ether0/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@
(
# On 2/11/2025 James kept seeing on the g3 server cluster:
# > huggingface_hub.errors.HfHubHTTPError: 504 Server Error: Gateway Time-out for
# > url: https://huggingface.co/api/datasets/futurehouse/molrqa/paths-info/abc123
# > url: https://huggingface.co/api/datasets/org/repo/paths-info/abc123
# And on 3/14 James saw this on the g3 server cluster:
# > huggingface_hub.errors.HfHubHTTPError: 502 Server Error: Bad Gateway for
# > url: https://huggingface.co/api/datasets/futurehouse/molrqa2/paths-info/abc123
# > url: https://huggingface.co/api/datasets/org/repo/paths-info/abc123
isinstance(x, HfHubHTTPError)
and x.response.status_code
in {HTTPStatus.BAD_GATEWAY.value, HTTPStatus.GATEWAY_TIMEOUT.value}
)
# On 4/14/2025 James kept seeing on the g5 server cluster:
# > datasets.exceptions.DatasetNotFoundError:
# > Dataset 'futurehouse/molrqa2' doesn't exist on the Hub or cannot be accessed.
# > Dataset 'org/repo' doesn't exist on the Hub or cannot be accessed.
or isinstance(x, DatasetNotFoundError)
)
),
Expand Down
8 changes: 4 additions & 4 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_valid_mol_eval(yhat: str, y: str, expected: float) -> None:
"CCCO",
1.0,
None,
id="molrqa2-train-e54b6a29-143e-5373-a3bd-007b8800d31e",
id="exact-match",
),
pytest.param(
"CCCO",
Expand Down Expand Up @@ -416,15 +416,15 @@ def test_formula_diff(f1: str, f2: str, expected: float) -> None:
"O=C(/C=C/C1=CC=CC=C1)OC[C@H]1O[C@@H](O[C@@H]2O[C@@H]3C[C@H]4[C@H](O)[C@@H](O)[C@@](O)(CO3)[C@@H]24)[C@H](O)[C@@H](O)[C@@H]1O",
None,
1,
id="molrqa2-train-1",
id="passing-1",
),
pytest.param(
"CC(C)C[C@H](NC(=O)[C@H](Cc1c[nH]cn1)NC(=O)[C@H](Cc1ccccc1)NC(=O)OC(C)(C)C)[C@@H](O)[C@@H](O)CC(C)C",
None,
1,
id="molrqa2-train-2",
id="passing-2",
),
pytest.param("CCCCCBr", "CCCCCBr", 1, id="molrqa2-train-3"),
pytest.param("CCCCCBr", "CCCCCBr", 1, id="passing-3"),
],
)
def test_is_reasonable_ring_system(
Expand Down