diff --git a/openhands-sdk/openhands/sdk/tool/tool.py b/openhands-sdk/openhands/sdk/tool/tool.py index 0079efb3eb..bd44040948 100644 --- a/openhands-sdk/openhands/sdk/tool/tool.py +++ b/openhands-sdk/openhands/sdk/tool/tool.py @@ -1,4 +1,5 @@ import re +import threading from abc import ABC, abstractmethod from collections.abc import Sequence from typing import ( @@ -42,6 +43,7 @@ ObservationT = TypeVar("ObservationT", bound=Observation) _action_types_with_risk: dict[type, type] = {} _action_types_with_summary: dict[type, type] = {} +_action_type_lock = threading.Lock() def _camel_to_snake(name: str) -> str: @@ -477,25 +479,26 @@ def resolve_kind(cls, kind: str) -> type: def create_action_type_with_risk(action_type: type[Schema]) -> type[Schema]: - action_type_with_risk = _action_types_with_risk.get(action_type) - if action_type_with_risk: + with _action_type_lock: + action_type_with_risk = _action_types_with_risk.get(action_type) + if action_type_with_risk: + return action_type_with_risk + + action_type_with_risk = type( + f"{action_type.__name__}WithRisk", + (action_type,), + { + "security_risk": Field( + # We do NOT add default value to make it a required field + # default=risk.SecurityRisk.UNKNOWN + description="The LLM's assessment of the safety risk of this action.", # noqa:E501 + ), + "__annotations__": {"security_risk": risk.SecurityRisk}, + }, + ) + _action_types_with_risk[action_type] = action_type_with_risk return action_type_with_risk - action_type_with_risk = type( - f"{action_type.__name__}WithRisk", - (action_type,), - { - "security_risk": Field( - # We do NOT add default value to make it an required field - # default=risk.SecurityRisk.UNKNOWN - description="The LLM's assessment of the safety risk of this action.", - ), - "__annotations__": {"security_risk": risk.SecurityRisk}, - }, - ) - _action_types_with_risk[action_type] = action_type_with_risk - return action_type_with_risk - def _create_action_type_with_summary(action_type: type[Schema]) -> 
type[Schema]: """Create a new action type with summary field for LLM to predict. @@ -509,24 +512,25 @@ def _create_action_type_with_summary(action_type: type[Schema]) -> type[Schema]: Returns: A new type that includes the summary field """ - action_type_with_summary = _action_types_with_summary.get(action_type) - if action_type_with_summary: - return action_type_with_summary - - action_type_with_summary = type( - f"{action_type.__name__}WithSummary", - (action_type,), - { - "summary": Field( - default=None, - description=( - "A concise summary (approximately 10 words) describing what " - "this specific action does. Focus on the key operation and target. " - "Example: 'List all Python files in current directory'" + with _action_type_lock: + action_type_with_summary = _action_types_with_summary.get(action_type) + if action_type_with_summary: + return action_type_with_summary + + action_type_with_summary = type( + f"{action_type.__name__}WithSummary", + (action_type,), + { + "summary": Field( + default=None, + description=( + "A concise summary (approximately 10 words) describing what " + "this specific action does. Focus on the key operation and target. 
" # noqa:E501 + "Example: 'List all Python files in current directory'" + ), ), - ), - "__annotations__": {"summary": str | None}, - }, - ) - _action_types_with_summary[action_type] = action_type_with_summary - return action_type_with_summary + "__annotations__": {"summary": str | None}, + }, + ) + _action_types_with_summary[action_type] = action_type_with_summary + return action_type_with_summary diff --git a/tests/sdk/tool/test_tool.py b/tests/sdk/tool/test_tool.py index 62ec358bab..cd61727088 100644 --- a/tests/sdk/tool/test_tool.py +++ b/tests/sdk/tool/test_tool.py @@ -1,9 +1,26 @@ """Test Tool class functionality.""" +import gc +import threading +from abc import ABC + import pytest -from pydantic import ValidationError +from pydantic import Field, ValidationError +from openhands.sdk.tool import Action from openhands.sdk.tool.spec import Tool +from openhands.sdk.tool.tool import ( + _action_types_with_risk, + _action_types_with_summary, + _create_action_type_with_summary, + create_action_type_with_risk, +) +from openhands.sdk.utils.models import _get_checked_concrete_subclasses + + +# Must live at module scope (Pydantic rejects function-local classes). +class _Bug2199Action(Action, ABC): + cmd: str = Field(description="test") def test_tool_minimal(): @@ -177,3 +194,75 @@ def test_tool_repr(): assert "Tool" in repr_str assert "TerminalTool" in repr_str + + +def test_issue_2199_1(request): + """Reproduce issue #2199: duplicate dynamic Action wrapper classes. + + When subagent threads concurrently call ``create_action_type_with_risk`` + or ``_create_action_type_with_summary`` on the same input, a TOCTOU race + on the module-level cache can create two distinct class objects with the + same ``__name__``, causing ``_get_checked_concrete_subclasses(Action)`` + to raise ``ValueError("Duplicate class definition ...")``. 
+ + Many threads wrapping the same type must all get the same class object. + Ref: https://github.com/OpenHands/software-agent-sdk/issues/2199 + """ + saved_risk = dict(_action_types_with_risk) + + def _cleanup(): + _action_types_with_risk.clear() + _action_types_with_risk.update(saved_risk) + gc.collect() + + request.addfinalizer(_cleanup) + + results: list[type] = [] + barrier = threading.Barrier(8) + + def worker(): + barrier.wait() + results.append(create_action_type_with_risk(_Bug2199Action)) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(set(id(r) for r in results)) == 1, "All threads must get the same class" + _get_checked_concrete_subclasses(Action) + + +def test_issue_2199_2(request): + """ + Same race test for _create_action_type_with_summary. + """ + saved_risk = dict(_action_types_with_risk) + saved_summary = dict(_action_types_with_summary) + + def _cleanup(): + _action_types_with_risk.clear() + _action_types_with_risk.update(saved_risk) + _action_types_with_summary.clear() + _action_types_with_summary.update(saved_summary) + gc.collect() + + request.addfinalizer(_cleanup) + + with_risk = create_action_type_with_risk(_Bug2199Action) + results: list[type] = [] + barrier = threading.Barrier(8) + + def worker(): + barrier.wait() + results.append(_create_action_type_with_summary(with_risk)) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(set(id(r) for r in results)) == 1, "All threads must get the same class" + _get_checked_concrete_subclasses(Action)