Skip to content
8 changes: 4 additions & 4 deletions python/agents/personalized-shopping/deployment/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import vertexai
from vertexai.preview.reasoning_engines import AdkApp
from vertexai import agent_engines
from dotenv import load_dotenv
import os

import vertexai
from dotenv import load_dotenv
from personalized_shopping.agent import root_agent
from vertexai import agent_engines
from vertexai.preview.reasoning_engines import AdkApp

load_dotenv()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
os.environ["GOOGLE_CLOUD_LOCATION"] = "global"
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")

import torch
import torch # noqa: E402

# Workaround to Resolve the PyTorch-Streamlit Incompatibility Issue
torch.classes.__path__ = []

# Initialize webshop environment (requires Java)
# If Java is not available (e.g., in CI), set webshop_env to None
try:
from .shared_libraries.init_env import init_env, webshop_env
from .shared_libraries.init_env import init_env, webshop_env # noqa: E402
except Exception:
webshop_env = None
init_env = None

from . import agent
from . import agent # noqa: F401, E402
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from google.adk.agents import Agent
from google.adk.tools import FunctionTool

from .tools.search import search
from .tools.click import click

from .prompt import personalized_shopping_agent_instruction
from .tools.click import click
from .tools.search import search

root_agent = Agent(
model="gemini-2.5-flash",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,7 @@ def get_webshop_env():
if _webshop_env is None:
_webshop_env = init_env(num_product_items)
_webshop_env.reset()
print(f"Finished initializing WebshopEnv with {num_product_items} items.")
print(
f"Finished initializing WebshopEnv with {num_product_items} items."
)
return _webshop_env
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

import json
import sys

from tqdm import tqdm

sys.path.insert(0, "../")

from web_agent_site.engine.engine import load_products
from web_agent_site.engine.engine import load_products # noqa: E402

all_products, *_ = load_products(filepath="../data/items_shuffle.json")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .envs.web_agent_text_env import WebAgentTextEnv
from .envs.web_agent_text_env import WebAgentTextEnv # noqa: F401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this import be moved to the file where WebAgentTextEnv is actually used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Out of scope for the purpose of this PR though.

Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

""" """

from ast import literal_eval
from collections import defaultdict
from decimal import Decimal
import json
import os
import random
import re
from ast import literal_eval
from collections import defaultdict
from decimal import Decimal

from flask import render_template_string
from pyserini.search.lucene import LuceneSearcher
Expand Down Expand Up @@ -179,7 +179,9 @@ def get_top_n_product_from_keywords(
docs = [search_engine.doc(hit.docid) for hit in hits]
top_n_asins = [json.loads(doc.raw())["id"] for doc in docs]
top_n_products = [
product_item_dict[asin] for asin in top_n_asins if asin in product_item_dict
product_item_dict[asin]
for asin in top_n_asins
if asin in product_item_dict
]
return top_n_products

Expand Down Expand Up @@ -334,7 +336,10 @@ def load_products(filepath, num_products=None, human_goals=True):
option_values = []
for option_content in option_contents:
option_value = (
option_content["value"].strip().replace("/", " | ").lower()
option_content["value"]
.strip()
.replace("/", " | ")
.lower()
)
option_image = option_content.get("image", None)

Expand Down Expand Up @@ -364,7 +369,9 @@ def load_products(filepath, num_products=None, human_goals=True):
if asin in human_attributes:
products[i]["instructions"] = human_attributes[asin]
else:
products[i]["instruction_text"] = attributes[asin].get("instruction", None)
products[i]["instruction_text"] = attributes[asin].get(
"instruction", None
)

products[i]["instruction_attributes"] = attributes[asin].get(
"instruction_attributes", None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

"""Functions for specifying goals and reward calculations."""

from collections import defaultdict
import itertools
import random
from rich import print
from collections import defaultdict

import spacy
from rich import print
from thefuzz import fuzz

from .normalize import normalize_color

nlp = spacy.load("en_core_web_sm")
Expand Down Expand Up @@ -53,7 +55,9 @@ def get_human_goals(all_products, product_prices):
price_range = [p for p in PRICE_RANGE if p > price][:4]
if len(price_range) >= 2:
_, price_upper = sorted(random.sample(price_range, 2))
price_text = f", and price lower than {price_upper:.2f} dollars"
price_text = (
f", and price lower than {price_upper:.2f} dollars"
)
else:
price_upper = 1000000
price_text = ""
Expand All @@ -67,7 +71,8 @@ def get_human_goals(all_products, product_prices):
"query": item["query"],
"name": item["name"],
"product_category": item["product_category"],
"instruction_text": product["instruction"].strip(".") + price_text,
"instruction_text": product["instruction"].strip(".")
+ price_text,
"attributes": attributes,
"price_upper": price_upper,
"goal_options": product["instruction_options"],
Expand All @@ -86,7 +91,10 @@ def get_synthetic_goals(all_products, product_prices):
goals = []
cnt_atts = defaultdict(int)
for product in all_products:
if "instruction_text" not in product or product["instruction_text"] is None:
if (
"instruction_text" not in product
or product["instruction_text"] is None
):
continue
product_goals = []
asin = product["asin"]
Expand All @@ -111,14 +119,18 @@ def get_synthetic_goals(all_products, product_prices):
options = product["options"]
option_names = sorted(options)
combinations = list(
itertools.product(*(options[option_name] for option_name in option_names))
itertools.product(
*(options[option_name] for option_name in option_names)
)
)
for combination in combinations:
goal_options = dict()
for i, o in enumerate(combination):
# option_text.append(f'{option_names[i]}: {o}')
goal_options[option_names[i]] = o
option_text = ", and ".join([f"{k}: {v}" for k, v in goal_options.items()])
option_text = ", and ".join(
[f"{k}: {v}" for k, v in goal_options.items()]
)
option_text = " with " + option_text if option_text else ""
product_goals.append(
{
Expand All @@ -138,9 +150,9 @@ def get_synthetic_goals(all_products, product_prices):
cnt_atts[att] += 1
goals += product_goals
for goal in goals:
goal["weight"] = sum(1.0 / cnt_atts[att] for att in goal["attributes"]) / len(
goal["attributes"]
)
goal["weight"] = sum(
1.0 / cnt_atts[att] for att in goal["attributes"]
) / len(goal["attributes"])
return goals


Expand All @@ -152,7 +164,9 @@ def get_type_reward(purchased_product, goal):
purchased_product_category = [
x.strip() for x in purchased_product["product_category"].split("›")
]
goal_product_category = [x.strip() for x in goal["product_category"].split("›")]
goal_product_category = [
x.strip() for x in goal["product_category"].split("›")
]
category_match = (
len(set(purchased_product_category) & set(goal_product_category)) >= 2
)
Expand Down Expand Up @@ -245,15 +259,21 @@ def get_option_reward(purchased_options, goal_options):
break

# Calculate option reward as fraction of goal options hit
r_option = num_option_matches / len(goal_options) if len(goal_options) > 0 else None
r_option = (
num_option_matches / len(goal_options)
if len(goal_options) > 0
else None
)
return r_option, num_option_matches


def get_reward(purchased_product, goal, price, options, **kwargs):
"""Get cumulative reward score for purchased product and goal"""
r_type_dict = get_type_reward(purchased_product, goal)

r_price = (price <= goal["price_upper"]) if goal["price_upper"] > 0 else None
r_price = (
(price <= goal["price_upper"]) if goal["price_upper"] > 0 else None
)

r_att, num_attr_matches = get_attribute_reward(purchased_product, goal)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import json
import random
import string
import time
from collections import defaultdict

import gym
import numpy as np
import torch
from bs4 import BeautifulSoup
from bs4.element import Comment
from flask import Flask
import gym
from gym.envs.registration import register
import numpy as np
import torch

from ..engine.engine import (
ACTION_TO_TEMPLATE,
BACK_TO_SEARCH,
Expand All @@ -45,7 +47,6 @@
random_idx,
)


app = Flask(__name__)


Expand Down Expand Up @@ -124,7 +125,11 @@ def step(self, action):
action_name, action_arg = parse_action(action)
if action_arg is not None:
action_arg = action_arg.lower()
if action_name == "search" and action_arg is not None and action_arg != "":
if (
action_name == "search"
and action_arg is not None
and action_arg != ""
):
status = self.browser.search(action_arg)
elif (
action_name == "click"
Expand Down Expand Up @@ -214,7 +219,9 @@ def observation(self):
elif self.observation_mode == "url":
return self.state["url"]
else:
raise ValueError(f"Observation mode {self.observation_mode} not supported.")
raise ValueError(
f"Observation mode {self.observation_mode} not supported."
)

@property
def state(self):
Expand Down Expand Up @@ -246,13 +253,20 @@ def convert_html_to_text(self, html, simple=False):
processed_t = f"[button] {t} [button_]"
elif t.parent.name == "label": # options
if f'"{t}"' in self.state["url"]:
processed_t = f" [clicked button] {t} [clicked button_]"
processed_t = (
f" [clicked button] {t} [clicked button_]"
)
observation = f"You have clicked {t}.\n" + observation
else:
processed_t = f" [button] {t} [button_]"
elif t.parent.get("class") == ["product-link"]: # product asins
if f"{t}" in self.server.user_sessions[self.session]["asins"]:
processed_t = f"\n[clicked button] {t} [clicked button_]"
if (
f"{t}"
in self.server.user_sessions[self.session]["asins"]
):
processed_t = (
f"\n[clicked button] {t} [clicked button_]"
)
else:
processed_t = f"\n[button] {t} [button_]"
else: # regular, unclickable text
Expand All @@ -273,7 +287,9 @@ def reset(self, session=None, instruction_text=None):
self.session = self.session_prefix + self.session

init_url = f"{self.base_url}/{self.session}"
self.browser.get(init_url, session_id=self.session, session_int=session_int)
self.browser.get(
init_url, session_id=self.session, session_int=session_int
)

self.text_to_clickable = None
self.instruction_text = (
Expand All @@ -295,7 +311,9 @@ def close(self):

def tag_visible(element):
ignore = {"style", "script", "head", "title", "meta", "[document]"}
return element.parent.name not in ignore and not isinstance(element, Comment)
return element.parent.name not in ignore and not isinstance(
element, Comment
)


class SimServer:
Expand Down Expand Up @@ -332,7 +350,9 @@ def __init__(
)
)
self.search_engine = init_search_engine(num_products=num_products)
self.goals = get_goals(self.all_products, self.product_prices, human_goals)
self.goals = get_goals(
self.all_products, self.product_prices, human_goals
)
self.show_attrs = show_attrs

# Fix outcome for random shuffling of goals
Expand All @@ -342,7 +362,9 @@ def __init__(
# Apply `filter_goals` parameter if exists to select speific goal(s)
if filter_goals is not None:
self.goals = [
goal for (i, goal) in enumerate(self.goals) if filter_goals(i, goal)
goal
for (i, goal) in enumerate(self.goals)
if filter_goals(i, goal)
]

# Imposes `limit` on goals via random selection
Expand Down Expand Up @@ -561,7 +583,9 @@ def receive(self, session_id, current_url, session_int=None, **kwargs):
if session_id not in self.user_sessions:
idx = (
session_int
if (session_int is not None and isinstance(session_int, int))
if (
session_int is not None and isinstance(session_int, int)
)
else random_idx(self.cum_weights)
)
goal = self.goals[idx]
Expand Down Expand Up @@ -679,7 +703,9 @@ def __init__(self, server):

def get(self, url, session_id=None, session_int=None):
"""Set browser variables to corresponding link, page HTML for URL"""
self.session_id = url.split("/")[-1] if session_id is None else session_id
self.session_id = (
url.split("/")[-1] if session_id is None else session_id
)
self.page_source, _, _ = self.server.receive(
self.session_id, self.current_url, session_int=session_int
)
Expand Down
Loading