Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions autogpt_platform/backend/backend/blocks/human_in_the_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionStatus
from backend.data.human_review import ReviewResult
Expand Down Expand Up @@ -61,17 +62,18 @@ def __init__(self):
categories={BlockCategory.BASIC},
input_schema=HumanInTheLoopBlock.Input,
output_schema=HumanInTheLoopBlock.Output,
block_type=BlockType.HUMAN_IN_THE_LOOP,
test_input={
"data": {"name": "John Doe", "age": 30},
"name": "User profile data",
"editable": True,
},
test_output=[
("reviewed_data", {"name": "John Doe", "age": 30}),
("status", "approved"),
("review_message", ""),
("reviewed_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"_is_safe_mode": lambda *_args, **_kwargs: True, # Mock safe mode check
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
Expand All @@ -80,9 +82,36 @@ def __init__(self):
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)

async def _is_safe_mode(
self, graph_id: str, user_id: str, graph_version: int
) -> bool:
"""Get the safe mode setting for a graph, defaulting to True (safe mode ON)."""
graph = await get_database_manager_async_client().get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if graph and graph.settings:
return graph.settings.safe_mode
return True

async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)

async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)

async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)

async def run(
self,
input_data: Input,
Expand All @@ -99,13 +128,22 @@ async def run(

This method uses one function to handle the complete workflow - checking existing reviews
and creating pending ones as needed.

If safe_mode is disabled, this block will automatically approve the data without requiring human intervention.
"""
try:
logger.debug(f"HITL block executing for node {node_exec_id}")
if not await self._is_safe_mode(graph_id, user_id, graph_version):
logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
)
# Automatically approve the data
yield "status", "approved"
yield "reviewed_data", input_data.data
yield "review_message", "Auto-approved (safe mode disabled)"
return

try:
# Use the data layer to handle the complete workflow
db_client = get_database_manager_async_client()
result = await db_client.get_or_create_human_review(
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
Expand All @@ -128,8 +166,7 @@ async def run(
# Set node status to REVIEW so execution manager can't mark it as COMPLETED
# The VALID_STATUS_TRANSITIONS will then prevent any unwanted status changes
# Use the proper wrapper function to ensure websocket events are published
await async_update_node_execution_status(
db_client=db_client,
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
Expand All @@ -144,7 +181,7 @@ async def run(
# Review is complete (approved or rejected) - check if unprocessed
if not result.processed:
# Mark as processed before yielding
await db_client.update_review_processed_status(
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)

Expand Down
51 changes: 36 additions & 15 deletions autogpt_platform/backend/backend/blocks/time_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import UserContext
from backend.data.model import SchemaField
from backend.data.model import USER_TIMEZONE_NOT_SET, SchemaField
from backend.util.clients import get_database_manager_async_client

# Shared timezone literal type for all time/date blocks
TimezoneLiteral = Literal[
Expand Down Expand Up @@ -62,6 +62,13 @@
logger = logging.getLogger(__name__)


async def _get_effective_timezone(user_id: str) -> str:
user = await get_database_manager_async_client().get_user_by_id(user_id)
if user and user.timezone and user.timezone != USER_TIMEZONE_NOT_SET:
return user.timezone
return "UTC"


def _get_timezone(
format_type: Any, # Any format type with timezone and use_user_timezone attributes
user_timezone: str | None,
Expand Down Expand Up @@ -137,6 +144,10 @@ class TimeISO8601Format(BaseModel):


class GetCurrentTimeBlock(Block):
async def _get_user_timezone(self, user_id: str) -> str:
"""Get the effective timezone for a user, defaulting to UTC if not set."""
return await _get_effective_timezone(user_id)

class Input(BlockSchemaInput):
trigger: str = SchemaField(
description="Trigger any data to output the current time"
Expand Down Expand Up @@ -185,13 +196,14 @@ def __init__(self):
lambda t: "T" in t and ("+" in t or "Z" in t),
), # Check for ISO format with timezone
],
test_mock={
"_get_user_timezone": lambda *args, **kwargs: "UTC",
},
)

async def run(
self, input_data: Input, *, user_context: UserContext, **kwargs
) -> BlockOutput:
# Extract timezone from user_context (always present)
effective_timezone = user_context.timezone
async def run(self, input_data: Input, *, user_id: str, **kwargs) -> BlockOutput:
# Get user timezone from database
effective_timezone = await self._get_user_timezone(user_id)

# Get the appropriate timezone
tz = _get_timezone(input_data.format_type, effective_timezone)
Expand Down Expand Up @@ -227,6 +239,9 @@ class DateISO8601Format(BaseModel):


class GetCurrentDateBlock(Block):
async def _get_user_timezone(self, user_id: str) -> str:
return await _get_effective_timezone(user_id)

class Input(BlockSchemaInput):
trigger: str = SchemaField(
description="Trigger any data to output the current date"
Expand Down Expand Up @@ -296,12 +311,13 @@ def __init__(self):
and t[7] == "-", # ISO date format YYYY-MM-DD
),
],
test_mock={
"_get_user_timezone": lambda *args, **kwargs: "UTC",
},
)

async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Extract timezone from user_context (required keyword argument)
user_context: UserContext = kwargs["user_context"]
effective_timezone = user_context.timezone
async def run(self, input_data: Input, *, user_id: str, **kwargs) -> BlockOutput:
effective_timezone = await self._get_user_timezone(user_id)

try:
offset = int(input_data.offset)
Expand Down Expand Up @@ -338,6 +354,10 @@ class ISO8601Format(BaseModel):


class GetCurrentDateAndTimeBlock(Block):
async def _get_user_timezone(self, user_id: str) -> str:
"""Get the effective timezone for a user, defaulting to UTC if not set."""
return await _get_effective_timezone(user_id)

class Input(BlockSchemaInput):
trigger: str = SchemaField(
description="Trigger any data to output the current date and time"
Expand Down Expand Up @@ -402,12 +422,13 @@ def __init__(self):
< timedelta(seconds=10), # 10 seconds error margin for ISO format.
),
],
test_mock={
"_get_user_timezone": lambda *args, **kwargs: "UTC",
},
)

async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Extract timezone from user_context (required keyword argument)
user_context: UserContext = kwargs["user_context"]
effective_timezone = user_context.timezone
async def run(self, input_data: Input, *, user_id: str, **kwargs) -> BlockOutput:
effective_timezone = await self._get_user_timezone(user_id)

# Get the appropriate timezone
tz = _get_timezone(input_data.format_type, effective_timezone)
Expand Down
10 changes: 10 additions & 0 deletions autogpt_platform/backend/backend/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class BlockType(Enum):
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"


class BlockCategory(Enum):
Expand Down Expand Up @@ -796,3 +797,12 @@ def get_io_block_ids() -> Sequence[str]:
for id, B in get_blocks().items()
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
]


@cached(ttl_seconds=3600)
def get_human_in_the_loop_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
]
4 changes: 1 addition & 3 deletions autogpt_platform/backend/backend/data/credit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.execution import NodeExecutionEntry, UserContext
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import openai_credentials
Expand Down Expand Up @@ -86,7 +86,6 @@ async def test_block_credit_usage(server: SpinTestServer):
"type": openai_credentials.type,
},
},
user_context=UserContext(timezone="UTC"),
),
)
assert spending_amount_1 > 0
Expand All @@ -101,7 +100,6 @@ async def test_block_credit_usage(server: SpinTestServer):
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
user_context=UserContext(timezone="UTC"),
),
)
assert spending_amount_2 == 0
Expand Down
15 changes: 1 addition & 14 deletions autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ def from_db(_graph_exec: AgentGraphExecution):

def to_graph_execution_entry(
self,
user_context: "UserContext",
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
parent_graph_exec_id: Optional[str] = None,
):
Expand All @@ -375,7 +374,6 @@ def to_graph_execution_entry(
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
user_context=user_context,
parent_graph_exec_id=parent_graph_exec_id,
)

Expand Down Expand Up @@ -448,9 +446,7 @@ def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
end_time=_node_exec.endedTime,
)

def to_node_execution_entry(
self, user_context: "UserContext"
) -> "NodeExecutionEntry":
def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
Expand All @@ -460,7 +456,6 @@ def to_node_execution_entry(
node_id=self.node_id,
block_id=self.block_id,
inputs=self.input_data,
user_context=user_context,
)


Expand Down Expand Up @@ -1099,19 +1094,12 @@ async def get_latest_node_execution(
# ----------------- Execution Infrastructure ----------------- #


class UserContext(BaseModel):
"""Generic user context for graph execution containing user-specific settings."""

timezone: str


class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
user_context: UserContext
parent_graph_exec_id: Optional[str] = None


Expand All @@ -1124,7 +1112,6 @@ class NodeExecutionEntry(BaseModel):
node_id: str
block_id: str
inputs: BlockInput
user_context: UserContext


class ExecutionQueue(Generic[T]):
Expand Down
Loading
Loading